pytorch邏輯回歸實(shí)現(xiàn)步驟詳解
1. 導(dǎo)入庫(kù)
機(jī)器學(xué)習(xí)的任務(wù)分為兩大類:分類和回歸
分類是對(duì)一堆目標(biāo)進(jìn)行識(shí)別歸類,例如貓狗分類、手寫(xiě)數(shù)字分類等等
回歸是對(duì)某樣事物接下來(lái)行為的預(yù)測(cè),例如預(yù)測(cè)天氣等等
這次我們要完成的任務(wù)是邏輯回歸,雖然名字叫做回歸,其實(shí)是個(gè)二元分類的任務(wù)
首先看看我們需要的庫(kù)文件
torch.nn 是專門(mén)為神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)的接口
matplotlib 用來(lái)繪制圖像,幫助可視化任務(wù)
torch 定義張量,數(shù)據(jù)的傳輸利用張量來(lái)實(shí)現(xiàn)
optim 優(yōu)化器的包,例如SGD等
numpy 數(shù)據(jù)處理的包
2. 定義數(shù)據(jù)集
簡(jiǎn)單說(shuō)明一下任務(wù),想在一個(gè)正方形的區(qū)域內(nèi)生成若干點(diǎn),然后手工設(shè)計(jì)label,最后通過(guò)神經(jīng)網(wǎng)絡(luò)的訓(xùn)練,畫(huà)出決策邊界
假設(shè):正方形的邊長(zhǎng)是2,左下角的坐標(biāo)為(0,0),右上角的坐標(biāo)為(2,2)
然后我們手工定義分界線 y = x ,在分界線的上方定義為藍(lán)色,下方定義為紅色
2.1 生成數(shù)據(jù)
首先生成數(shù)據(jù)的代碼為
首先通過(guò)rand(0-1的均勻分布)生成200個(gè)點(diǎn),并將他們擴(kuò)大2倍,x1代表橫坐標(biāo),x2代表縱坐標(biāo)
然后定義一下分類,這里簡(jiǎn)單介紹一下zip函數(shù)。
zip會(huì)將這里的a,b對(duì)應(yīng)打包成一對(duì),這樣i對(duì)應(yīng)的就是(1,‘a’),i[0] 對(duì)應(yīng)的就是1 2 3
再回到我們的代碼,因?yàn)槲覀円獙?shí)現(xiàn)的是二元分類,所以我們定義兩個(gè)不同的類型,用pos,neg存起來(lái)。然后我們知道i[1] 代表的是 x2 ,i[0] 代表的是x1 , 所以 x2 - x1 <0 也就是也就是在直線y=x的下面為pos類型。否則,為neg類型
最后,我們需要將pos,neg類型的繪制出來(lái)。因?yàn)閜os里面其實(shí)是類似于(1,1)這樣的坐標(biāo),因?yàn)閜os.append(i) 里面的 i 其實(shí)是(x1,x2) 的坐標(biāo)形式, 所以我們將pos 里面的第一個(gè)元素x1定義為賦值給橫坐標(biāo),第二個(gè)元素x2賦值給縱坐標(biāo)
然后通過(guò)scatter 繪制離散的點(diǎn)就可以,將pos 繪制成 red 顏色,neg 繪制成 blue 顏色,如圖
2.2 設(shè)置label
我們進(jìn)行的其實(shí)是有監(jiān)督學(xué)習(xí),所以需要label
這里需要注意的是,不同于回歸任務(wù),x1不是輸入,x2也不是輸出。應(yīng)該x1,x2都是輸入的元素,也就是特征feature。所以我們應(yīng)該將紅色的點(diǎn)集設(shè)置一個(gè)標(biāo)簽,例如 1 ,藍(lán)色的點(diǎn)集設(shè)置一個(gè)標(biāo)簽,例如 0.
實(shí)現(xiàn)代碼如下
很容易理解,訓(xùn)練集x_data 應(yīng)該是所有樣本,也就是pos和neg的所以元素。而之前介紹了x1,x2都是輸入的特征,那么x_data的shape 應(yīng)該是 [200,2] 的。而y_data 只有1(pos 紅色)類別,或者 0(neg 藍(lán)色)類型,所以y_data 的shape 應(yīng)該是 [200,1] 的。y_data view的原因是變成矩陣的形式而不是向量的形式
這里的意思是,假如坐標(biāo)是(1.5,0.5)那么應(yīng)該落在紅色區(qū)域,那么這個(gè)點(diǎn)的標(biāo)簽就是1
3. 搭建網(wǎng)絡(luò)+優(yōu)化器
網(wǎng)絡(luò)的類型很簡(jiǎn)單,不再贅述。至于為什么要繼承nn.Module或者super那步是干啥的不用管,基本上都是這樣寫(xiě)的,記住就行。
需要注意的是我們輸入的特征是(n * 2) ,所以Linear 應(yīng)該是(2,1)
二元分類最后的輸出一般選用sigmoid函數(shù)
這里的損失函數(shù)我們選擇BCE,二元交叉熵?fù)p失函數(shù)。
算法為隨機(jī)梯度下降
4. 訓(xùn)練
訓(xùn)練的過(guò)程也比較簡(jiǎn)單,就是將模型的預(yù)測(cè)輸出值和真實(shí)的label作比較。然后將梯度歸零,在反向傳播并且更新梯度。
5. 繪制決策邊界
這里模型訓(xùn)練完成后,將w0,w1 ,b取出來(lái),然后繪制出直線
這里要繪制的是w0 * x1+ w1 * x2 + b = 0 ,因?yàn)樽铋_(kāi)始介紹了x1代表橫坐標(biāo)x,x2代表縱坐標(biāo)y。通過(guò)變形可知y = (-w0 * x1 - b ) / w1,結(jié)果如圖
程序輸出的損失為
最后,w0 = 4.1911 , w1 = -4.0290 ,b = 0.0209 ,近似等于y = x,和我們剛開(kāi)始定義的分界線類似
6. 代碼
import torch.nn as nn import matplotlib.pyplot as plt import torch from torch import optim import numpy as np torch.manual_seed(1) # 保證程序隨機(jī)生成數(shù)一樣 x1 = torch.rand(200) * 2 x2 = torch.rand(200) * 2 data = zip(x1,x2) pos = [] # 定義類型 1 neg = [] # 定義類型 2 def classification(data): for i in data: if(i[1] - i[0] < 0): pos.append(i) else: neg.append(i) classification(data) pos_x = [i[0] for i in pos] pos_y = [i[1] for i in pos] neg_x = [i[0] for i in neg] neg_y = [i[1] for i in neg] plt.scatter(pos_x,pos_y,c='r') plt.scatter(neg_x,neg_y,c='b') plt.show() x_data = [[i[0],i[1]] for i in pos] x_data.extend([[i[0],i[1]] for i in neg]) x_data = torch.Tensor(x_data) # 輸入數(shù)據(jù) feature y_data = [1 for i in range(len(pos))] y_data.extend([0 for i in range(len(neg))]) y_data = torch.Tensor(y_data).view(-1,1) # 對(duì)應(yīng)的標(biāo)簽 class LogisticRegressionModel(nn.Module): # 定義網(wǎng)絡(luò) def __init__(self): super(LogisticRegressionModel,self).__init__() self.linear = nn.Linear(2,1) self.sigmoid = nn.Sigmoid() def forward(self,x): x = self.linear(x) x = self.sigmoid(x) return x model = LogisticRegressionModel() criterion = nn.BCELoss() optimizer = optim.SGD(model.parameters(),lr =0.01) for epoch in range(10000): y_pred = model(x_data) loss = criterion(y_pred,y_data) # 計(jì)算損失值 if epoch % 1000 == 0: print(epoch,loss.item()) # 打印損失值 optimizer.zero_grad() # 梯度清零 loss.backward() # 反向傳播 optimizer.step() # 梯度更新 w = model.linear.weight[0] # 取出訓(xùn)練完成的結(jié)果 w0 = w[0] w1 = w[1] b = model.linear.bias.item() with torch.no_grad(): # 繪制決策邊界,這里不需要計(jì)算梯度 x= torch.arange(0,3).view(-1,1) y = (- w0 * x - b) / w1 plt.plot(x.numpy(),y.numpy()) plt.scatter(pos_x,pos_y,c='r') plt.scatter(neg_x,neg_y,c='b') plt.xlim(0,2) plt.ylim(0,2) plt.show()
程序結(jié)果
到此這篇關(guān)于pytorch邏輯回歸實(shí)現(xiàn)步驟詳解的文章就介紹到這了,更多相關(guān)pytorch邏輯回歸內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)自動(dòng)裝機(jī)功能案例分析
這篇文章主要介紹了Python實(shí)現(xiàn)自動(dòng)裝機(jī)功能,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-10-10手把手教你利用Python創(chuàng)建一個(gè)游戲窗口
pygame是python用來(lái)寫(xiě)游戲的擴(kuò)展包,用這個(gè)擴(kuò)展包,可以比較容易的構(gòu)造一個(gè)游戲窗口,這篇文章主要給大家介紹了關(guān)于如何利用Python創(chuàng)建一個(gè)游戲窗口的相關(guān)資料,需要的朋友可以參考下2022-07-07Python設(shè)計(jì)模式中的結(jié)構(gòu)型適配器模式
這篇文章主要介紹了Python設(shè)計(jì)中的結(jié)構(gòu)型適配器模式,適配器模式即Adapter?Pattern,將一個(gè)類的接口轉(zhuǎn)換成為客戶希望的另外一個(gè)接口,下文內(nèi)容具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-02-02Python中出現(xiàn)IndentationError:unindent does not match any outer
今天在網(wǎng)上copy的一段代碼,代碼很簡(jiǎn)單,每行看起來(lái)該縮進(jìn)的都縮進(jìn)了,運(yùn)行的時(shí)候出現(xiàn)了如下錯(cuò)誤,IndentationError: unindent does not match any outer indentation level,如果看起來(lái)縮進(jìn)正常所有tab與空格混用就會(huì)出現(xiàn)這個(gè)問(wèn)題2019-01-01Python使用MD5加密算法對(duì)字符串進(jìn)行加密操作示例
這篇文章主要介紹了Python使用MD5加密算法對(duì)字符串進(jìn)行加密操作,結(jié)合實(shí)例形式分析了Python實(shí)現(xiàn)md5加密相關(guān)操作技巧,需要的朋友可以參考下2018-03-03在Python中操作MongoDB的詳細(xì)教程和案例分享
MongoDB是一個(gè)高性能、開(kāi)源、無(wú)模式的文檔型數(shù)據(jù)庫(kù),非常適合存儲(chǔ)JSON風(fēng)格的數(shù)據(jù),Python作為一種廣泛使用的編程語(yǔ)言,通過(guò)PyMongo庫(kù)可以方便地與MongoDB進(jìn)行交互,本文將詳細(xì)介紹如何在Python中使用PyMongo庫(kù)來(lái)操作MongoDB數(shù)據(jù)庫(kù),需要的朋友可以參考下2024-08-08