pytorch使用nn.Moudle實現(xiàn)邏輯回歸
更新時間:2022年07月30日 15:42:35 作者:ALEN.Z
這篇文章主要為大家詳細(xì)介紹了pytorch使用nn.Moudle實現(xiàn)邏輯回歸,文中示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下
本文實例為大家分享了pytorch使用nn.Moudle實現(xiàn)邏輯回歸的具體代碼,供大家參考,具體內(nèi)容如下
內(nèi)容
pytorch使用nn.Moudle實現(xiàn)邏輯回歸
問題
loss下降不明顯
解決方法
#源代碼 out的數(shù)據(jù)接收方式 ? ? ?if torch.cuda.is_available(): ? ? ? ? ?x_data=Variable(x).cuda() ? ? ? ? ?y_data=Variable(y).cuda() ? ? ?else: ? ? ? ? ?x_data=Variable(x) ? ? ? ? ?y_data=Variable(y) ? ?? ? ? out=logistic_model(x_data) ?#根據(jù)邏輯回歸模型擬合出的y值 ? ? loss=criterion(out.squeeze(),y_data) ?#計算損失函數(shù)
#源代碼 out的數(shù)據(jù)有拼裝數(shù)據(jù)直接輸入 # ? ? if torch.cuda.is_available(): # ? ? ? ? x_data=Variable(x).cuda() # ? ? ? ? y_data=Variable(y).cuda() # ? ? else: # ? ? ? ? x_data=Variable(x) # ? ? ? ? y_data=Variable(y) ? ?? ? ? out=logistic_model(x_data) ?#根據(jù)邏輯回歸模型擬合出的y值 ? ? loss=criterion(out.squeeze(),y_data) ?#計算損失函數(shù) ? ? print_loss=loss.data.item() ?#得出損失函數(shù)值
源代碼
import torch from torch import nn from torch.autograd import Variable import matplotlib.pyplot as plt import numpy as np #生成數(shù)據(jù) sample_nums = 100 mean_value = 1.7 bias = 1 n_data = torch.ones(sample_nums, 2) x0 = torch.normal(mean_value * n_data, 1) + bias ? ? ?# 類別0 數(shù)據(jù) shape=(100, 2) y0 = torch.zeros(sample_nums) ? ? ? ? ? ? ? ? ? ? ? ? # 類別0 標(biāo)簽 shape=(100, 1) x1 = torch.normal(-mean_value * n_data, 1) + bias ? ? # 類別1 數(shù)據(jù) shape=(100, 2) y1 = torch.ones(sample_nums) ? ? ? ? ? ? ? ? ? ? ? ? ?# 類別1 標(biāo)簽 shape=(100, 1) x_data = torch.cat((x0, x1), 0) ?#按維數(shù)0行拼接 y_data = torch.cat((y0, y1), 0) #畫圖 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn') plt.show() # 利用torch.nn實現(xiàn)邏輯回歸 class LogisticRegression(nn.Module): ? ? def __init__(self): ? ? ? ? super(LogisticRegression, self).__init__() ? ? ? ? self.lr = nn.Linear(2, 1) ? ? ? ? self.sm = nn.Sigmoid() ? ? def forward(self, x): ? ? ? ? x = self.lr(x) ? ? ? ? x = self.sm(x) ? ? ? ? return x ? ?? logistic_model = LogisticRegression() # if torch.cuda.is_available(): # ? ? logistic_model.cuda() #loss函數(shù)和優(yōu)化 criterion = nn.BCELoss() optimizer = torch.optim.SGD(logistic_model.parameters(), lr=0.01, momentum=0.9) #開始訓(xùn)練 #訓(xùn)練10000次 for epoch in range(10000): # ? ? if torch.cuda.is_available(): # ? ? ? ? x_data=Variable(x).cuda() # ? ? ? ? y_data=Variable(y).cuda() # ? ? else: # ? ? ? ? x_data=Variable(x) # ? ? ? ? y_data=Variable(y) ? ?? ? ? out=logistic_model(x_data) ?#根據(jù)邏輯回歸模型擬合出的y值 ? ? loss=criterion(out.squeeze(),y_data) ?#計算損失函數(shù) ? ? print_loss=loss.data.item() ?#得出損失函數(shù)值 ? ? #反向傳播 ? ? loss.backward() ? ? optimizer.step() ? ? optimizer.zero_grad() ? ?? ? ? mask=out.ge(0.5).float() ?#以0.5為閾值進(jìn)行分類 ? ? correct=(mask==y_data).sum().squeeze() ?#計算正確預(yù)測的樣本個數(shù) ? ? acc=correct.item()/x_data.size(0) ?#計算精度 ? ? #每隔20輪打印一下當(dāng)前的誤差和精度 ? ? if (epoch+1)%100==0: ? ? ? ? print('*'*10) ? ? ? ? print('epoch {}'.format(epoch+1)) ?#誤差 ? ? ? ? print('loss is {:.4f}'.format(print_loss)) ? ? ? ? print('acc is {:.4f}'.format(acc)) ?#精度 ? ? ? ?? ? ? ? ?? w0, w1 = logistic_model.lr.weight[0] w0 = float(w0.item()) w1 = float(w1.item()) b = float(logistic_model.lr.bias.item()) plot_x = np.arange(-7, 7, 0.1) plot_y = (-w0 * plot_x - b) / w1 plt.xlim(-5, 7) plt.ylim(-7, 7) plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=logistic_model(x_data)[:,0].cpu().data.numpy(), s=100, lw=0, cmap='RdYlGn') plt.plot(plot_x, plot_y) plt.show()
輸出結(jié)果
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
簡單介紹一下pyinstaller打包以及安全性的實現(xiàn)
這篇文章主要介紹了簡單介紹一下pyinstaller打包以及安全性的實現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-06-06調(diào)用其他python腳本文件里面的類和方法過程解析
這篇文章主要介紹了調(diào)用其他python腳本文件里面的類和方法過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-11-11使用Python的toolz庫開始函數(shù)式編程的方法
這篇文章主要介紹了使用Python的toolz庫開始函數(shù)式編程的方法,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下2018-11-11