Pytorch模型中的parameter與buffer用法
Parameter 和 buffer
If you have parameters in your model, which should be saved and restored in the state_dict, but not trained by the optimizer, you should register them as buffers.Buffers won't be returned in model.parameters(), so that the optimizer won't have a change to update them.
模型中需要保存下來(lái)的參數(shù)包括兩種
一種是反向傳播需要被optimizer更新的,稱之為 parameter
一種是反向傳播不需要被optimizer更新,稱之為 buffer
第一種參數(shù)我們可以通過(guò) model.parameters() 返回;第二種參數(shù)我們可以通過(guò) model.buffers() 返回。因?yàn)槲覀兊哪P捅4娴氖?state_dict 返回的 OrderDict,所以這兩種參數(shù)不僅要滿足是否需要被更新的要求,還需要被保存到OrderDict。
那么現(xiàn)在的問(wèn)題是這兩種參數(shù)如何創(chuàng)建呢,創(chuàng)建好了如何保存到OrderDict呢?
第一種參數(shù)有兩種方式
我們可以直接將模型的成員變量(http://self.xxx) 通過(guò)nn.Parameter() 創(chuàng)建,會(huì)自動(dòng)注冊(cè)到parameters中,可以通過(guò)model.parameters() 返回,并且這樣創(chuàng)建的參數(shù)會(huì)自動(dòng)保存到OrderDict中去;
通過(guò)nn.Parameter() 創(chuàng)建普通Parameter對(duì)象,不作為模型的成員變量,然后將Parameter對(duì)象通過(guò)register_parameter()進(jìn)行注冊(cè),可以通model.parameters() 返回,注冊(cè)后的參數(shù)也會(huì)自動(dòng)保存到OrderDict中去;
第二種參數(shù)我們需要?jiǎng)?chuàng)建tensor
然后將tensor通過(guò)register_buffer()進(jìn)行注冊(cè),可以通model.buffers() 返回,注冊(cè)完后參數(shù)也會(huì)自動(dòng)保存到OrderDict中去。
Pytorch中Module,Parameter和Buffer區(qū)別
下文都將torch.nn簡(jiǎn)寫成nn
Module: 就是我們常用的torch.nn.Module類,你定義的所有網(wǎng)絡(luò)結(jié)構(gòu)都必須繼承這個(gè)類。
Buffer: buffer和parameter相對(duì),就是指那些不需要參與反向傳播的參數(shù)
示例如下:
class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.my_tensor = torch.randn(1) # 參數(shù)直接作為模型類成員變量 self.register_buffer('my_buffer', torch.randn(1)) # 參數(shù)注冊(cè)為 buffer self.my_param = nn.Parameter(torch.randn(1)) def forward(self, x): return x model = MyModel() print(model.state_dict()) >>>OrderedDict([('my_param', tensor([1.2357])), ('my_buffer', tensor([-0.9982]))]) Parameter: 是nn.parameter.Paramter,也就是組成Module的參數(shù)。例如一個(gè)nn.Linear通常由weight和bias參數(shù)組成。它的特點(diǎn)是默認(rèn)requires_grad=True,也就是說(shuō)訓(xùn)練過(guò)程中需要反向傳播的,就需要使用這個(gè) import torch.nn as nn fc = nn.Linear(2,2) # 讀取參數(shù)的方式一 fc._parameters >>> OrderedDict([('weight', Parameter containing: tensor([[0.4142, 0.0424], [0.3940, 0.0796]], requires_grad=True)), ('bias', Parameter containing: tensor([-0.2885, 0.5825], requires_grad=True))]) # 讀取參數(shù)的方式二(推薦這種) for n, p in fc.named_parameters(): print(n,p) >>>weight Parameter containing: tensor([[0.4142, 0.0424], [0.3940, 0.0796]], requires_grad=True) bias Parameter containing: tensor([-0.2885, 0.5825], requires_grad=True) # 讀取參數(shù)的方式三 for p in fc.parameters(): print(p) >>>Parameter containing: tensor([[0.4142, 0.0424], [0.3940, 0.0796]], requires_grad=True) Parameter containing: tensor([-0.2885, 0.5825], requires_grad=True)
通過(guò)上面的例子可以看到,nn.parameter.Paramter的requires_grad屬性值默認(rèn)為True。另外上面例子給出了三種讀取parameter的方法,推薦使用后面兩種,因?yàn)槭且缘善鞯姆绞絹?lái)讀取,第一種方式是一股腦的把參數(shù)全丟給你,要是模型很大,估計(jì)你的電腦會(huì)吃不消。
另外需要介紹的是_parameters是nn.Module在__init__()函數(shù)中就定義了的一個(gè)OrderDict類,這個(gè)可以通過(guò)看下面給出的部分源碼看到,可以看到還初始化了很多其他東西,其實(shí)原理都大同小異,你理解了這個(gè)之后,其他的也是同樣的道理。
class Module(object): ... def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._modules = OrderedDict() self.training = True
每當(dāng)我們給一個(gè)成員變量定義一個(gè)nn.parameter.Paramter的時(shí)候,都會(huì)自動(dòng)注冊(cè)到_parameters,具體的步驟如下:
import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() # 下面兩種定義方式均可 self.p1 = nn.paramter.Paramter(torch.tensor(1.0)) print(self._parameters) self.p2 = nn.Paramter(torch.tensor(2.0)) print(self._parameters)
首先運(yùn)行super(MyModel, self).__init__(),這樣MyModel就初始化了_paramters等一系列的OrderDict,此時(shí)所有變量還都是空的。
self.p1 = nn.paramter.Paramter(torch.tensor(1.0)): 這行代碼會(huì)觸發(fā)nn.Module預(yù)定義好的__setattr__函數(shù),該函數(shù)部分源碼如下:
def __setattr__(self, name, value): ... params = self.__dict__.get('_parameters') if isinstance(value, Parameter): if params is None: raise AttributeError( "cannot assign parameters before Module.__init__() call") remove_from(self.__dict__, self._buffers, self._modules) self.register_parameter(name, value) ...
__setattr__函數(shù)作用簡(jiǎn)單理解就是判斷你定義的參數(shù)是否正確,如果正確就繼續(xù)調(diào)用register_parameter函數(shù)進(jìn)行注冊(cè),這個(gè)函數(shù)簡(jiǎn)單概括就是做了下面這件事
def register_parameter(self,name,param): ... self._parameters[name]=param
下面我們實(shí)例化這個(gè)模型看結(jié)果怎樣
model = MyModel() >>>OrderedDict([('p1', Parameter containing: tensor(1., requires_grad=True))]) OrderedDict([('p1', Parameter containing: tensor(1., requires_grad=True)), ('p2', Parameter containing: tensor(2., requires_grad=True))])
結(jié)果和上面分析的一致。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python多項(xiàng)式回歸的實(shí)現(xiàn)方法
這篇文章主要介紹了Python多項(xiàng)式回歸的實(shí)現(xiàn)方法,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2019-03-03Python 流媒體播放器的實(shí)現(xiàn)(基于VLC)
這篇文章主要介紹了Python 流媒體播放器的實(shí)現(xiàn)(基于VLC),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04一文搞懂Python讀取text,CSV,JSON文件的方法
文件處理是一種用于創(chuàng)建文件、寫入數(shù)據(jù)和從中讀取數(shù)據(jù)的過(guò)程,Python 擁有豐富的用于處理不同文件類型的包,從而使得我們可以更加輕松方便的完成文件處理的工作,本文將來(lái)為大家詳細(xì)講講2022-06-06初探利用Python進(jìn)行圖文識(shí)別(OCR)
這篇文章主要介紹了初探利用Python進(jìn)行圖文識(shí)別(OCR),小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2019-02-02Numpy數(shù)組的廣播機(jī)制的實(shí)現(xiàn)
這篇文章主要介紹了Numpy數(shù)組的廣播機(jī)制的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11tensorboard 可以顯示graph,卻不能顯示scalar的解決方式
今天小編就為大家分享一篇tensorboard 可以顯示graph,卻不能顯示scalar的解決方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02