3 回答

TA貢獻1998條經驗 獲得超6個贊
我不知道如何讓線性回歸線(又稱最佳擬合線)橫跨整個圖形的寬度。目前它似乎只延伸到最左邊和最右邊的數據點為止,不再往外延伸。我該如何解決這個問題?
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import *
import MySQLdb
# Query the world database and plot GNP vs. population with a fitted line.
def mysql_select_all():
    """Query GNP and population for a few countries and plot a linear fit.

    Reads ``GNP`` and ``Population`` from the ``country`` table of the
    ``world`` database on localhost, draws them as a scatter plot, and
    overlays a least-squares straight line. The line is evaluated at the
    final x-axis limits rather than only at the data points, so it spans
    the whole width of the plot instead of stopping at the leftmost and
    rightmost points.

    Raises:
        ValueError: if fewer than two rows with non-NULL GNP and Population
            are returned, since a straight line cannot be fitted.
    """
    from contextlib import closing

    countries = ('United States', 'Canada', 'United Kingdom', 'Russia',
                 'Germany', 'Poland', 'Italy', 'China', 'India', 'Japan',
                 'Brazil')
    # The original used LIKE without wildcards, i.e. plain equality; IN with
    # one driver placeholder per name says the same thing more directly.
    # NOTE(review): names that do not exactly match country.Name are
    # silently absent from the result -- confirm each one exists.
    placeholders = ', '.join(['%s'] * len(countries))
    sql = ("SELECT GNP, Population FROM country WHERE Name IN (%s)"
           % placeholders)

    conn = MySQLdb.connect(host='localhost',
                           user='root',
                           passwd='XXXXX',
                           db='world')
    # closing() releases the cursor and the connection even if the query
    # raises; the data is fully fetched before plotting starts.
    with closing(conn), closing(conn.cursor()) as cursor:
        cursor.execute(sql, countries)
        rows = cursor.fetchall()

    # Convert the driver values straight to float (no repr() round trip) and
    # skip rows where either column is NULL.
    pairs = [(float(gnp), float(pop)) for gnp, pop in rows
             if gnp is not None and pop is not None]
    if len(pairs) < 2:
        raise ValueError(
            "need at least two (GNP, Population) rows to fit a line")
    data = np.array(pairs)
    list_x, list_y = data[:, 0], data[:, 1]

    _, ax1 = plt.subplots()
    ax1.scatter(list_x, list_y, color='darkgreen', s=100)
    ax1.set_xlabel("GNP (US dollars)", fontsize=30, labelpad=50)
    # NOTE(review): Population is plotted unscaled; confirm the "in billions"
    # wording matches the column's units.
    ax1.set_ylabel("Population(in billions)", fontsize=30, labelpad=50)
    # set_xticks may widen the view limits so that every tick is visible,
    # which is why the fitted line is drawn only after the ticks are fixed.
    ax1.set_xticks([n * 1000000 for n in range(1, 10)])
    ax1.tick_params(axis='x', labelrotation=45, labelsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Fit y = p1[0] * x + p1[1] and evaluate it at the current x-limits so the
    # line runs edge to edge. Then pin those limits so that adding the line
    # does not re-trigger x autoscaling (which would add margins again).
    p1 = np.polyfit(list_x, list_y, 1)
    x_edges = ax1.get_xlim()
    ax1.plot(x_edges, np.polyval(p1, x_edges), 'r-')
    ax1.set_xlim(x_edges)
    plt.show()


mysql_select_all()
而延長之後的期望效果如下(原文此處的內容未能保留):

TA貢獻1906條經驗 獲得超10個贊
如果您希望 x 軸的範圍不超出數據本身(這樣回歸線就會從圖的左邊緣一直延伸到右邊緣),只需執行以下操作:
# Replacement plotting code: assumes list_x, list_y and p1 (from np.polyfit)
# were computed as in the question.
fig, ax = plt.subplots()
# Remove the default horizontal padding so the autoscaled x-limits equal the
# data range instead of extending a few percent beyond it.
ax.margins(x=0)
# Use the Axes methods (object-oriented API) instead of pyplot functions such
# as plt.plot, so every call explicitly targets `ax`.
ax.scatter(list_x, list_y, color='darkgreen', s=100)
ax.set_xlabel("GNP (US dollars)", fontsize=30)
ax.set_ylabel("Population(in billions)", fontsize=30)
# Axes.set_xticks does not take rotation/fontsize unless labels are passed
# (older versions raise, newer ones raise or ignore them), so style the tick
# labels through tick_params instead.
ax.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,
               7000000, 8000000, 9000000])
ax.tick_params(axis='x', labelrotation=45, labelsize=14)
ax.tick_params(axis='y', labelsize=14)
# set_xticks can widen the view limits to include every tick, so margins(x=0)
# alone may still leave a gap. Draw the fitted line across the final x-limits
# and then pin them so adding the line does not autoscale the axis again.
x_edges = ax.get_xlim()
ax.plot(x_edges, np.polyval(p1, x_edges), 'r-')
ax.set_xlim(x_edges)
添加回答
舉報