亚洲在线久爱草,狠狠天天香蕉网,天天搞日日干久草,伊人亚洲日本欧美

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

Matplotlib - 將線性回歸線擴展到圖的整個寬度

Matplotlib - 將線性回歸線擴展到圖的整個寬度

神不在的星期二 2021-10-26 15:42:51
我似乎無法弄清楚如何使線性回歸線(又名最佳擬合線)跨越圖形的整個寬度。它似乎只是上升了左邊最遠的數據點和右邊最遠的數據點,沒有進一步。我將如何解決這個問題?import matplotlib.pyplot as pltimport numpy as npfrom scipy import statsfrom scipy.interpolate import *import MySQLdb# connect to MySQL databasedef mysql_select_all():    conn = MySQLdb.connect(host='localhost',                           user='root',                           passwd='XXXXX',                           db='world')    cursor = conn.cursor()    sql = """        SELECT            GNP, Population        FROM            country        WHERE            Name LIKE 'United States'                OR Name LIKE 'Canada'                OR Name LIKE 'United Kingdom'                OR Name LIKE 'Russia'                OR Name LIKE 'Germany'                OR Name LIKE 'Poland'                OR Name LIKE 'Italy'                OR Name LIKE 'China'                OR Name LIKE 'India'                OR Name LIKE 'Japan'                OR Name LIKE 'Brazil';    """    cursor.execute(sql)    result = cursor.fetchall()    list_x = []    list_y = []    for row in result:        list_x.append(('%r' % (row[0],)))    for row in result:        list_y.append(('%r' % (row[1],)))    list_x = list(map(float, list_x))    list_y = list(map(float, list_y))    fig = plt.figure()    ax1 = plt.subplot2grid((1,1), (0,0))    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression    ax1.xaxis.labelpad = 50    ax1.yaxis.labelpad = 50    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression      plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)    plt.xlabel("GNP (US dollars)", fontsize=30)    plt.ylabel("Population(in billions)", fontsize=30)    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,                 7000000, 8000000, 9000000],  rotation=45, fontsize=14)    plt.yticks(fontsize=14)    plt.show()    cursor.close()mysql_select_all()
查看完整描述

3 回答

?
米琪卡哇伊

TA貢獻1998條經驗 獲得超6個贊

我似乎無法弄清楚如何使線性回歸線(又名最佳擬合線)跨越圖形的整個寬度。它似乎只是上升了左邊最遠的數據點和右邊最遠的數據點,沒有進一步。我將如何解決這個問題?


import matplotlib.pyplot as plt

import numpy as np

from scipy import stats

from scipy.interpolate import *

import MySQLdb


# connect to MySQL database

def mysql_select_all():

    conn = MySQLdb.connect(host='localhost',

                           user='root',

                           passwd='XXXXX',

                           db='world')

    cursor = conn.cursor()

    sql = """

        SELECT

            GNP, Population

        FROM

            country

        WHERE

            Name LIKE 'United States'

                OR Name LIKE 'Canada'

                OR Name LIKE 'United Kingdom'

                OR Name LIKE 'Russia'

                OR Name LIKE 'Germany'

                OR Name LIKE 'Poland'

                OR Name LIKE 'Italy'

                OR Name LIKE 'China'

                OR Name LIKE 'India'

                OR Name LIKE 'Japan'

                OR Name LIKE 'Brazil';

    """


    cursor.execute(sql)

    result = cursor.fetchall()


    list_x = []

    list_y = []


    for row in result:

        list_x.append(('%r' % (row[0],)))


    for row in result:

        list_y.append(('%r' % (row[1],)))


    list_x = list(map(float, list_x))

    list_y = list(map(float, list_y))


    fig = plt.figure()

    ax1 = plt.subplot2grid((1,1), (0,0))


    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression


    ax1.xaxis.labelpad = 50

    ax1.yaxis.labelpad = 50


    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression  

    plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)

    plt.xlabel("GNP (US dollars)", fontsize=30)

    plt.ylabel("Population(in billions)", fontsize=30)

    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 

                7000000, 8000000, 9000000],  rotation=45, fontsize=14)

    plt.yticks(fontsize=14)


    plt.show()

    cursor.close()


mysql_select_all()

http://img1.sycdn.imooc.com//6177b19d0001c4a019581284.jpg

而延長之后,

http://img1.sycdn.imooc.com//6177b1aa00016a3619621284.jpg


查看完整回答
反對 回復 2021-10-26
?
隔江千里

TA貢獻1906條經驗 獲得超10個贊

如果您希望繪圖不超出 x 軸上的數據,只需執行以下操作:


fig, ax = plt.subplots()

ax.margins(x=0)

# Don't use plt.plot

ax.plot(list_x, np.polyval(p1,list_x),'r-')

ax.scatter(list_x, list_y, color = 'darkgreen', s = 100)

ax.set_xlabel("GNP (US dollars)", fontsize=30)

ax.set_ylabel("Population(in billions)", fontsize=30)

ax.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000],  rotation=45, fontsize=14)

ax.tick_params(axis='y', labelsize=14)


查看完整回答
反對 回復 2021-10-26
  • 3 回答
  • 0 關注
  • 462 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯系客服咨詢優惠詳情

幫助反饋 APP下載

慕課網APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網微信公眾號