PyTorch模型創(chuàng)建與nn.Module構(gòu)建
模型創(chuàng)建與nn.Module
文章和代碼已經(jīng)歸檔至【Github倉庫:https://github.com/timerring/dive-into-AI 】
創(chuàng)建網(wǎng)絡(luò)模型通常有2個(gè)要素:
- 構(gòu)建子模塊
- 拼接子模塊
class LeNet(nn.Module): # 子模塊創(chuàng)建 ? ?def __init__(self, classes): ? ? ? ?super(LeNet, self).__init__() ? ? ? ?self.conv1 = nn.Conv2d(3, 6, 5) ? ? ? ?self.conv2 = nn.Conv2d(6, 16, 5) ? ? ? ?self.fc1 = nn.Linear(16*5*5, 120) ? ? ? ?self.fc2 = nn.Linear(120, 84) ? ? ? ?self.fc3 = nn.Linear(84, classes) # 子模塊拼接 ? ?def forward(self, x): ? ? ? ?out = F.relu(self.conv1(x)) ? ? ? ?out = F.max_pool2d(out, 2) ? ? ? ?out = F.relu(self.conv2(out)) ? ? ? ?out = F.max_pool2d(out, 2) ? ? ? ?out = out.view(out.size(0), -1) ? ? ? ?out = F.relu(self.fc1(out)) ? ? ? ?out = F.relu(self.fc2(out)) ? ? ? ?out = self.fc3(out) ? ? ? ?return out
調(diào)用net = LeNet(classes=2)
創(chuàng)建模型時(shí),會(huì)調(diào)用__init__()
方法創(chuàng)建模型的子模塊。
訓(xùn)練調(diào)用outputs = net(inputs)
時(shí),會(huì)進(jìn)入module.py
的call()
函數(shù)中:
def __call__(self, *input, **kwargs): ? ? ? ?for hook in self._forward_pre_hooks.values(): ? ? ? ? ? ?result = hook(self, input) ? ? ? ? ? ?if result is not None: ? ? ? ? ? ? ? ?if not isinstance(result, tuple): ? ? ? ? ? ? ? ? ? ?result = (result,) ? ? ? ? ? ? ? ?input = result ? ? ? ?if torch._C._get_tracing_state(): ? ? ? ? ? ?result = self._slow_forward(*input, **kwargs) ? ? ? ?else: ? ? ? ? ? ?result = self.forward(*input, **kwargs) ? ? ? ... ? ? ? ... ? ? ? ...
最終會(huì)調(diào)用result = self.forward(*input, **kwargs)
函數(shù),該函數(shù)會(huì)進(jìn)入模型的forward()
函數(shù)中,進(jìn)行前向傳播。
在 torch.nn
中包含 4 個(gè)模塊,如下圖所示。
本次重點(diǎn)就在于nn.Model的解析:
nn.Module
nn.Module
有 8 個(gè)屬性,都是OrderDict
(有序字典)的結(jié)構(gòu)。在 LeNet 的__init__()
方法中會(huì)調(diào)用父類nn.Module
的__init__()
方法,創(chuàng)建這 8 個(gè)屬性。
def __init__(self): ? ? ? ?""" ? ? ? Initializes internal Module state, shared by both nn.Module and ScriptModule. ? ? ? """ ? ? ? ?torch._C._log_api_usage_once("python.nn_module") ? ? ? ? ?self.training = True ? ? ? ?self._parameters = OrderedDict() ? ? ? ?self._buffers = OrderedDict() ? ? ? ?self._backward_hooks = OrderedDict() ? ? ? ?self._forward_hooks = OrderedDict() ? ? ? ?self._forward_pre_hooks = OrderedDict() ? ? ? ?self._state_dict_hooks = OrderedDict() ? ? ? ?self._load_state_dict_pre_hooks = OrderedDict() ? ? ? ?self._modules = OrderedDict()
- _parameters 屬性:存儲(chǔ)管理 nn.Parameter 類型的參數(shù)
- _modules 屬性:存儲(chǔ)管理 nn.Module 類型的參數(shù)
- _buffers 屬性:存儲(chǔ)管理緩沖屬性,如 BN 層中的 running_mean
- 5 個(gè) *_hooks 屬性:存儲(chǔ)管理鉤子函數(shù)
LeNet 的__init__()
中創(chuàng)建了 5 個(gè)子模塊,nn.Conv2d()
和nn.Linear()
都繼承于nn.module
,即一個(gè) module 都是包含多個(gè)子 module 的。
class LeNet(nn.Module): # 子模塊創(chuàng)建 ? ?def __init__(self, classes): ? ? ? ?super(LeNet, self).__init__() ? ? ? ?self.conv1 = nn.Conv2d(3, 6, 5) ? ? ? ?self.conv2 = nn.Conv2d(6, 16, 5) ? ? ? ?self.fc1 = nn.Linear(16*5*5, 120) ? ? ? ?self.fc2 = nn.Linear(120, 84) ? ? ? ?self.fc3 = nn.Linear(84, classes) ? ? ? ?... ? ? ? ?... ? ? ? ?...
當(dāng)調(diào)用net = LeNet(classes=2)
創(chuàng)建模型后,net
對(duì)象的 modules 屬性就包含了這 5 個(gè)子網(wǎng)絡(luò)模塊。
下面看下每個(gè)子模塊是如何添加到 LeNet 的_modules
屬性中的。以self.conv1 = nn.Conv2d(3, 6, 5)
為例,當(dāng)我們運(yùn)行到這一行時(shí),首先 Step Into 進(jìn)入 Conv2d
的構(gòu)造,然后 Step Out。右鍵Evaluate Expression
查看nn.Conv2d(3, 6, 5)
的屬性。
上面說了Conv2d
也是一個(gè) module,里面的_modules
屬性為空,_parameters
屬性里包含了該卷積層的可學(xué)習(xí)參數(shù),這些參數(shù)的類型是 Parameter,繼承自 Tensor。
此時(shí)只是完成了nn.Conv2d(3, 6, 5)
module 的創(chuàng)建。還沒有賦值給self.conv1
。在nn.Module
里有一個(gè)機(jī)制,會(huì)攔截所有的類屬性賦值操作(self.conv1
是類屬性) ,進(jìn)入到__setattr__()
函數(shù)中。我們?cè)俅?Step Into 就可以進(jìn)入__setattr__()
。
def __setattr__(self, name, value): ? ? ? ?def remove_from(*dicts): ? ? ? ? ? ?for d in dicts: ? ? ? ? ? ? ? ?if name in d: ? ? ? ? ? ? ? ? ? ?del d[name] ? ? ? ? ?params = self.__dict__.get('_parameters') ? ? ? ?if isinstance(value, Parameter): ? ? ? ? ? ?if params is None: ? ? ? ? ? ? ? ?raise AttributeError( ? ? ? ? ? ? ? ? ? ?"cannot assign parameters before Module.__init__() call") ? ? ? ? ? ?remove_from(self.__dict__, self._buffers, self._modules) ? ? ? ? ? ?self.register_parameter(name, value) ? ? ? ?elif params is not None and name in params: ? ? ? ? ? ?if value is not None: ? ? ? ? ? ? ? ?raise TypeError("cannot assign '{}' as parameter '{}' " ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?"(torch.nn.Parameter or None expected)" ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? .format(torch.typename(value), name)) ? ? ? ? ? ?self.register_parameter(name, value) ? ? ? ?else: ? ? ? ? ? ?modules = self.__dict__.get('_modules') ? ? ? ? ? ?if isinstance(value, Module): ? ? ? ? ? ? ? ?if modules is None: ? ? ? ? ? ? ? ? ? ?raise AttributeError( ? ? ? ? ? ? ? ? ? ? ? ?"cannot assign module before Module.__init__() call") ? ? ? ? ? ? ? ?remove_from(self.__dict__, self._parameters, self._buffers) ? ? ? ? ? ? ? ?modules[name] = value ? ? ? ? ? ?elif modules is not None and name in modules: ? ? ? ? ? ? ? ?if value is not None: ? ? ? ? ? ? ? ? ? ?raise TypeError("cannot assign '{}' as child module '{}' " ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?"(torch.nn.Module or None expected)" ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? .format(torch.typename(value), name)) ? ? ? ? ? ? ? ?modules[name] = value ? ? ? ? ? ... ? ? ? ? ? ... ? ? ? ? ? ...
在這里判斷 value 的類型是Parameter
還是Module
,存儲(chǔ)到對(duì)應(yīng)的有序字典中。
這里nn.Conv2d(3, 6, 5)
的類型是Module
,因此會(huì)執(zhí)行modules[name] = value
,key 是類屬性的名字conv1
,value 就是nn.Conv2d(3, 6, 5)
。
總結(jié)
- 一個(gè) module 里可包含多個(gè)子 module。比如 LeNet 是一個(gè) Module,里面包括多個(gè)卷積層、池化層、全連接層等子 module
- 一個(gè) module 相當(dāng)于一個(gè)運(yùn)算,必須實(shí)現(xiàn) forward() 函數(shù)
- 每個(gè) module 都有 8 個(gè)字典管理自己的屬性
以上就是PyTorch模型創(chuàng)建與nn.Module構(gòu)建的詳細(xì)內(nèi)容,更多關(guān)于PyTorch模型創(chuàng)建nn.Module的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
- 關(guān)于PyTorch中nn.Module類的簡(jiǎn)介
- Pytorch參數(shù)注冊(cè)和nn.ModuleList nn.ModuleDict的問題
- 人工智能學(xué)習(xí)PyTorch實(shí)現(xiàn)CNN卷積層及nn.Module類示例分析
- pytorch 中的重要模塊化接口nn.Module的使用
- 用pytorch的nn.Module構(gòu)造簡(jiǎn)單全鏈接層實(shí)例
- 淺析PyTorch中nn.Module的使用
- 對(duì)Pytorch中nn.ModuleList 和 nn.Sequential詳解
- PyTorch的nn.Module類的定義和使用介紹
相關(guān)文章
Python BeautifulSoup基本用法詳解(通過標(biāo)簽及class定位元素)
這篇文章主要介紹了Python BeautifulSoup基本用法(通過標(biāo)簽及class定位元素),本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-08-08利用Python查看微信共同好友功能的實(shí)現(xiàn)代碼
這篇文章主要介紹了利用Python查看微信共同好友功能的實(shí)現(xiàn)代碼,代碼簡(jiǎn)單易懂,非常不錯(cuò),具有一定的參考借鑒價(jià)值 ,需要的朋友可以參考下2019-04-04十個(gè)簡(jiǎn)單使用的Python自動(dòng)化腳本分享
今天小編給大家分享10個(gè)Python高級(jí)腳本,幫助我們減少無謂的時(shí)間浪費(fèi),提高工作學(xué)習(xí)中的效率。文中示例代碼講解詳細(xì),需要的可以參考一下2022-05-05Python 實(shí)現(xiàn)自動(dòng)完成A4標(biāo)簽排版打印功能
這篇文章主要介紹了Python 實(shí)現(xiàn)自動(dòng)完成A4標(biāo)簽排版打印功能,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-04-04Python中dictionary items()系列函數(shù)的用法實(shí)例
這篇文章主要介紹了Python中dictionary items()系列函數(shù)的用法,很實(shí)用的函數(shù),需要的朋友可以參考下2014-08-08python中itertools模塊zip_longest函數(shù)詳解
itertools模塊包含創(chuàng)建高效迭代器的函數(shù),這些函數(shù)的返回值不是list,而是iterator(可迭代對(duì)象),可以用各種方式對(duì)數(shù)據(jù)執(zhí)行循環(huán)操作,今天我們來詳細(xì)探討下zip_longest函數(shù)2018-06-06Python新手入門之常用關(guān)鍵字的簡(jiǎn)單示例詳解
關(guān)鍵字是預(yù)先保留的標(biāo)識(shí)符,每個(gè)關(guān)鍵字都有特殊的含義,下面這篇文章主要給大家介紹了關(guān)于Python新手入門之常用關(guān)鍵字的簡(jiǎn)單示例,文中通過代碼介紹的非常詳細(xì),需要的朋友可以參考下2024-03-03