PyTorch實現(xiàn)線性回歸詳細過程
一、實現(xiàn)步驟
1、準備數(shù)據(jù)
x_data = torch.tensor([[1.0],[2.0],[3.0]]) y_data = torch.tensor([[2.0],[4.0],[6.0]])
2、設(shè)計模型
class LinearModel(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(LinearModel,self).__init__() ? ? ? ? self.linear = torch.nn.Linear(1,1) ? ? ? ?? ? ? def forward(self, x): ? ? ? ? y_pred = self.linear(x) ? ? ? ? return y_pred ? ? ? ?? model = LinearModel() ?
3、構(gòu)造損失函數(shù)和優(yōu)化器
criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
4、訓練過程
epoch_list = [] loss_list = [] w_list = [] b_list = [] for epoch in range(1000): ? ? y_pred = model(x_data)?? ??? ??? ??? ??? ? ?# 計算預測值 ? ? loss = criterion(y_pred, y_data)?? ?# 計算損失 ? ? print(epoch,loss) ? ?? ? ? epoch_list.append(epoch) ? ? loss_list.append(loss.data.item()) ? ? w_list.append(model.linear.weight.item()) ? ? b_list.append(model.linear.bias.item()) ? ?? ? ? optimizer.zero_grad() ? # 梯度歸零 ? ? loss.backward() ? ? ? ? # 反向傳播 ? ? optimizer.step() ? ? ? ?# 更新
5、結(jié)果展示
展示最終的權(quán)重和偏置:
# 輸出權(quán)重和偏置
print('w = ',model.linear.weight.item())
print('b = ',model.linear.bias.item())結(jié)果為:
w = 1.9998501539230347
b = 0.0003405189490877092
模型測試:
# 測試模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ',y_test.data)
y_pred = ?tensor([[7.9997]])分別繪制損失值隨迭代次數(shù)變化的二維曲線圖和其隨權(quán)重與偏置變化的三維散點圖:
# 二維曲線圖
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
# 三維散點圖
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(w_list,b_list,loss_list,c='r')
#設(shè)置坐標軸
ax.set_xlabel('weight')
ax.set_ylabel('bias')
ax.set_zlabel('loss')
plt.show()結(jié)果如下圖所示:


到此這篇關(guān)于PyTorch實現(xiàn)線性回歸詳細過程的文章就介紹到這了,更多相關(guān)PyTorch線性回歸內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
二、參考文獻
相關(guān)文章
淺析python表達式4+0.5值的數(shù)據(jù)類型
在本篇文章里小編給大家整理的是一篇關(guān)于python表達式4+0.5值的數(shù)據(jù)類型的知識點內(nèi)容,需要的的朋友們學習下。2020-02-02
Python3使用Selenium獲取session和token方法詳解
這篇文章主要介紹了Python3使用Selenium獲取session和token方法詳解,需要的朋友可以參考下2021-02-02
Python+pyecharts繪制雙動態(tài)曲線教程詳解
pyecharts 是一個用于生成 Echarts 圖表的類庫。Echarts 是百度開源的一個數(shù)據(jù)可視化 JS 庫。用 Echarts 生成的圖可視化效果非常棒。本文將用pyecharts繪制雙動態(tài)曲線,需要的可以參考一下2022-06-06
推薦一款高效的python數(shù)據(jù)框處理工具Sidetable
這篇文章主要為大家介紹推薦一款高效的python數(shù)據(jù)框處理工具Sidetable,文章詳細的講解了Sidetable的安裝及用法,有需要的朋友可以借鑒參考下,希望能夠有所幫助2021-11-11
Python解決MySQL數(shù)據(jù)處理從SQL批量刪除報錯
這篇文章主要為大家介紹了Python解決MySQL數(shù)據(jù)處理從SQL批量刪除報錯,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-12-12
基于python的opencv圖像處理實現(xiàn)對斑馬線的檢測示例
這篇文章主要介紹了基于python的opencv圖像處理實現(xiàn)對斑馬線的檢測示例,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-11-11

