PyTorch使用torch.nn.Module模塊自定義模型結(jié)構(gòu)方式
以實現(xiàn)LeNet網(wǎng)絡(luò)為例,來學習使用pytorch如何搭建一個神經(jīng)網(wǎng)絡(luò)。
LeNet網(wǎng)絡(luò)的結(jié)構(gòu)如下圖所示。
一、使用torch.nn.Module類構(gòu)建網(wǎng)絡(luò)模型
搭建自己的網(wǎng)絡(luò)模型,我們需要新建一個類,讓它繼承torch.nn.Module類,并必須重寫Module類中的__init__()和forward()函數(shù)。
init()函數(shù)用來申明模型中各層的定義,forward()函數(shù)用來描述各層之間的連接關(guān)系,定義前向傳播計算的過程。
也就是說__init__()函數(shù)只是用來定義層,但并沒有將它們連接起來,forward()函數(shù)的作用就是將這些定義好的層連接成網(wǎng)絡(luò)。
使用上述方法實現(xiàn)LeNet網(wǎng)絡(luò)的代碼如下。
import torch.nn as nn class LeNet(nn.Module): def __init__(self): super().__init__() self.C1 = nn.Conv2d(1, 6, 5) self.sig = nn.Sigmoid() self.S2 = nn.MaxPool2d(2, 2) self.C3 = nn.Conv2d(6, 16, 5) self.S4 = nn.MaxPool2d(2, 2) self.C5 = nn.Conv2d(16, 120, 5) self.C6 = nn.Linear(120, 84) self.C7 = nn.Linear(84, 10) def forward(self, x): x1 = self.C1(x) x2 = self.sig(x1) x3 = self.S2(x2) x4 = self.C3(x3) x5 = self.sig(x4) x6 = self.S4(x5) x7 = self.C5(x6) x8 = self.C6(x7) y = self.C7(x8) return y net = LeNet() print(net)
結(jié)果為
在__init__()函數(shù)中,實例化了nn.Linear()、nn.Conv2d()這種pytorch封裝好的類,用來定義全連接層、卷積層等網(wǎng)絡(luò)層,并規(guī)定好它們的參數(shù)。
例如,self.C1 = nn.Conv2d(1, 6, 5)表示定義一個卷積層,它的卷積核輸入通道為1、輸出通道為6,大小為5×5。
真正向這個卷積層輸入數(shù)據(jù)是在forward()函數(shù)中,x1 = self.C1(x)表示將輸入x喂給卷積層,并得到輸出x1。
二、引入torch.nn.functional實現(xiàn)層的運算
引入torch.nn.functional模塊中的函數(shù),可以簡化__init__()函數(shù)中的內(nèi)容。
在__init__()函數(shù)中,我們可以只定義具有需要學習的參數(shù)的層,如卷積層、線性層,它們的權(quán)重都需要學習。
對于不需要學習參數(shù)的層,我們不需要在__init__()函數(shù)中定義,只需要在forward()函數(shù)中引入torch.nn.functional類中相關(guān)函數(shù)的調(diào)用。
例如LeNet中,我們在__init__()中只定義了卷積層和全連接層。池化層和激活函數(shù)只需要在forward()函數(shù)中,調(diào)用torch.nn.functional中的函數(shù)進行實現(xiàn)即可。
import torch.nn as nn import torch.nn.functional as F class LeNet(nn.Module): def __init__(self): super().__init__() self.C1 = nn.Conv2d(1, 6, 5) self.C3 = nn.Conv2d(6, 16, 5) self.C5 = nn.Conv2d(16, 120, 5) self.C6 = nn.Linear(120, 84) self.C7 = nn.Linear(84, 10) def forward(self, x): x1 = self.C1(x) x2 = F.sigmoid(x1) x3 = F.max_pool2d(x2) x4 = self.C3(x3) x5 = F.sigmoid(x4) x6 = F.max_pool2d(x5) x7 = self.C5(x6) x8 = self.C6(x7) y = self.C7(x8) return y net = LeNet() print(net)
運行結(jié)果為
當然,torch.nn.functional中也對需要學習參數(shù)的層進行了實現(xiàn),包括卷積層conv2d()和線性層linear(),但pytorch官方推薦我們只對不需要學習參數(shù)的層使用nn.functional中的函數(shù)。
對于一個層,使用nn.Xxx實現(xiàn)和使用nn.functional.xxx()實現(xiàn)的區(qū)別為:
1.nn.Xxx是一個類,繼承自nn.Modules,因此內(nèi)部會有很多屬性和方法,如train(), eval(),load_state_dict, state_dict 等。
2.nn.functional.xxx()僅僅是一個函數(shù)。作為一個類,nn.Xxx需要先實例化并傳入?yún)?shù),然后以函數(shù)調(diào)用的方式向?qū)嵗瘜ο笾形谷胼斎霐?shù)據(jù)。
conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding) output = conv(input)
nn.functional.xxx()是在調(diào)用時同時傳入輸入數(shù)據(jù)和設(shè)置參數(shù)。
output = nn.functional.conv2d(input, weight, bias, padding)
3.nn.Xxx不需要自己定義和管理權(quán)重,但nn.functional.xxx()需要自己定義權(quán)重,每次調(diào)用時要手動傳入。
三、Sequential類
1. 基礎(chǔ)使用
Sequential類繼承自Module類。對于一個簡單的序貫模型,可以不必自己再多寫一個類繼承Module類,而是直接使用pytorch提供的Sequential類,來將若干層或若干子模塊直接包裝成一個大的模塊。
例如在LeNet中,我們直接將各個層按順序排列好,然后用Sequential類包裝一下,就可以方便地構(gòu)建好一個神經(jīng)網(wǎng)路了。
import torch.nn as nn net = nn.Sequential( nn.Conv2d(1, 6, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2), nn.Conv2d(16, 120, 5), nn.Linear(120, 84), nn.Linear(84, 10) ) print(net) print(net[2]) #通過索引可以獲取到層
運行結(jié)果為
上面這種方法沒有給每一個層指定名稱,默認使用層的索引數(shù)0、1、2來命名。我們可以通過索引值來直接獲對應的層的信息。
當然,我們也可以給層指定名稱,但我們并不能通過名稱獲取層,想獲取層依舊要使用索引數(shù)字。
import torch.nn as nn from collections import OrderedDict net = nn.Sequential(OrderedDict([ ('C1', nn.Conv2d(1, 6, 5)), ('Sig1', nn.Sigmoid()), ('S2', nn.MaxPool2d(2, 2)), ('C3', nn.Conv2d(6, 16, 5)), ('Sig2', nn.Sigmoid()), ('S4', nn.MaxPool2d(2, 2)), ('C5', nn.Conv2d(16, 120, 5)), ('C6', nn.Linear(120, 84)), ('C7', nn.Linear(84, 10)) ])) print(net) print(net[2]) #通過索引可以獲取到層
運行結(jié)果為
也可以使用add_module函數(shù)向Sequential()中添加層。
import torch.nn as nn net = nn.Sequential() net.add_module('C1', nn.Conv2d(1, 6, 5)) net.add_module('Sig1', nn.Sigmoid()) net.add_module('S2', nn.MaxPool2d(2, 2)) net.add_module('C3', nn.Conv2d(6, 16, 5)) net.add_module('Sig2', nn.Sigmoid()) net.add_module('S4', nn.MaxPool2d(2, 2)) net.add_module('C5', nn.Conv2d(16, 120, 5)) net.add_module('C6', nn.Linear(120, 84)) net.add_module('C7', nn.Linear(84, 10)) print(net) print(net[2])
輸出為
2. 使用Sequential類將層包裝成子模塊
Sequential類也可以應用到自定義Module類的方法中,用來將幾個層包裝成一個大層(塊)。
當然Sequential依舊有三種使用方法,我們這里只使用第一種作為舉例。
import torch.nn as nn class LeNet(nn.Module): def __init__(self): super().__init__() self.conv = nn.Sequential( nn.Conv2d(1, 6, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2), nn.Conv2d(6, 16, 5), nn.Sigmoid(), nn.MaxPool2d(2, 2) ) self.fc = nn.Sequential( nn.Conv2d(16, 120, 5), nn.Linear(120, 84), nn.Linear(84, 10) ) def forward(self, x): x1 = self.conv(x) y = self.fc(x1) return y net = LeNet() print(net)
輸出為
四、ModuleList類和ModuleDict類
ModuleList類和ModuleDict類都是Modules類的子類,和Sequential類似,它也可以對若干層或子模塊進行打包,列表化的構(gòu)造網(wǎng)絡(luò)。
但與Sequential類不同的是,這兩個類只是將這些層定義并排列成列表(List)或字典(Dict),但并沒有將它們連接起來,也就是說并沒有實現(xiàn)forward()函數(shù)。
因此,這兩個類并不要求相鄰層的輸入輸出維度匹配,也不能直接向ModuleList和ModuleDict中直接喂入輸入數(shù)據(jù)。
ModuleList的訪問方法和普通的List類似。
net = nn.ModuleList([ nn.Linear(784, 256), nn.ReLU() ]) net.append(nn.Linear(256, 20)) # ModuleList可以像普通的List以下進行append操作 print(net[-1]) # ModuleList的訪問方法與List也相似 print(net) # X = torch.zeros(1, 784) # net(X) # 出錯。向ModuleList中輸入數(shù)據(jù)會出錯,因為ModuleList的作用僅僅是存儲 # 網(wǎng)絡(luò)的各個模塊,但并不連接它們,即沒有實現(xiàn)forward()
輸出為
ModuleDict的使用方法也和普通的字典類似。
net = nn.ModuleDict({ 'linear': nn.Linear(784, 256), 'act': nn.ReLU(), }) net['output'] = nn.Linear(256, 10) # 添加 print(net['linear']) # 訪問 print(net.output) print(net) # net(torch.zeros(1, 784)) # 會報NotImplementedError
輸出為
ModuleList和ModuleDict的使用是為了在定義前向傳播時能更加靈活。下面是官網(wǎng)上的一個關(guān)于ModuleList使用的例子。
class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x
此外,ModuleList和ModuleDict里,所有子模塊的參數(shù)都會被自動添加到神經(jīng)網(wǎng)絡(luò)中,這一點是與普通的List和Dict不同的。
舉個例子。
class Module_ModuleList(nn.Module): def __init__(self): super(Module_ModuleList, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10)]) class Module_List(nn.Module): def __init__(self): super(Module_List, self).__init__() self.linears = [nn.Linear(10, 10)] net1 = Module_ModuleList() net2 = Module_List() print("net1:") for p in net1.parameters(): print(p.size()) print("net2:") for p in net2.parameters(): print(p)
輸出為
五、向模型中輸入數(shù)據(jù)
假設(shè)我們向模型中輸入的數(shù)據(jù)為input,從模型中得到的前向傳播結(jié)果為output,則輸入數(shù)據(jù)的方法為
output = net(input)
net是對象名,我們直接將輸入作為參數(shù)傳入到對象名中,而并沒有顯示的調(diào)用forward()函數(shù),就完成了前向傳播的計算。
上面的寫法其實等價于
output = net.forward(input)
這是因為在torch.nn.Module類中,定義了__call__()函數(shù),其中就包括了對forward()方法的調(diào)用。
在python語法中__call__()方法使得類實例對象可以像調(diào)用普通函數(shù)那樣,以“對象名()”的形式使用,并執(zhí)行__call__()函數(shù)體中的內(nèi)容。
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python用requests模塊實現(xiàn)動態(tài)網(wǎng)頁爬蟲
大家好,本篇文章主要講的是Python用requests模塊實現(xiàn)動態(tài)網(wǎng)頁爬蟲,感興趣的同學趕快來看一看吧,對你有幫助的話記得收藏一下2022-02-02python中可以發(fā)生異常自動重試庫retrying
這篇文章主要介紹了python中可以發(fā)生異常自動重試庫retrying,retrying是一個極簡的使用Python編寫的庫,主題更多相關(guān)內(nèi)容需要的朋友可以參考一下2022-06-06初次部署django+gunicorn+nginx的方法步驟
這篇文章主要介紹了初次部署django+gunicorn+nginx的方法步驟,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-09-09Python中的數(shù)據(jù)標準化與反標準化全面指南
在數(shù)據(jù)處理和機器學習中,數(shù)據(jù)標準化是一項至關(guān)重要的預處理步驟,標準化能夠?qū)⒉煌叨群头秶臄?shù)據(jù)轉(zhuǎn)換為相同的標準,有助于提高模型的性能和穩(wěn)定性,Python提供了多種庫和函數(shù)來執(zhí)行數(shù)據(jù)標準化和反標準化,如Scikit-learn和TensorFlow2024-01-01Python實現(xiàn)讀取txt文件并畫三維圖簡單代碼示例
這篇文章主要介紹了Python實現(xiàn)讀取txt文件并畫三維圖簡單代碼示例,具有一定借鑒價值,需要的朋友可以參考下。2017-12-12Win7下搭建python開發(fā)環(huán)境圖文教程(安裝Python、pip、解釋器)
這篇文章主要為大家分享了Win7下搭建python開發(fā)環(huán)境圖文教程,本文主要介紹了安裝Python、pip、解釋器的詳細步驟,感興趣的小伙伴們可以參考一下2016-05-05