pytorch 狀態(tài)字典:state_dict使用詳解
pytorch 中的 state_dict 是一個(gè)簡單的python的字典對象,將每一層與它的對應(yīng)參數(shù)建立映射關(guān)系.(如model的每一層的weights及偏置等等)
(注意,只有那些參數(shù)可以訓(xùn)練的layer才會被保存到模型的state_dict中,如卷積層,線性層等等)
優(yōu)化器對象Optimizer也有一個(gè)state_dict,它包含了優(yōu)化器的狀態(tài)以及被使用的超參數(shù)(如lr, momentum,weight_decay等)
備注:
1) state_dict是在定義了model或optimizer之后pytorch自動生成的,可以直接調(diào)用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自動具備的函數(shù),可以直接調(diào)用
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因?yàn)?只有在執(zhí)行該命令后,"dropout層"及"batch normalization層"才會進(jìn)入 evalution 模態(tài). 而在"訓(xùn)練(training)模態(tài)"與"評估(evalution)模態(tài)"下,這兩層有不同的表現(xiàn)形式.
模態(tài)字典(state_dict)的保存(model是一個(gè)網(wǎng)絡(luò)結(jié)構(gòu)類的對象)
1.1)僅保存學(xué)習(xí)到的參數(shù),用以下命令
torch.save(model.state_dict(), PATH)
1.2)加載model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
備注:model.load_state_dict的操作對象是 一個(gè)具體的對象,而不能是文件名
2.1)保存整個(gè)model的狀態(tài),用以下命令
torch.save(model,PATH)
2.2)加載整個(gè)model的狀態(tài),用以下命令:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
state_dict 是一個(gè)python的字典格式,以字典的格式存儲,然后以字典的格式被加載,而且只加載key匹配的項(xiàng)
如何僅加載某一層的訓(xùn)練的到的參數(shù)(某一層的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
加載模型參數(shù)后,如何設(shè)置某層某參數(shù)的"是否需要訓(xùn)練"(param.requires_grad)
for param in list(model.pretrained.parameters()): param.requires_grad = False
注意: requires_grad的操作對象是tensor.
疑問:能否直接對某個(gè)層直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:經(jīng)測試,不可以.model.conv1 沒有requires_grad屬性.
全部測試代碼:
#-*-coding:utf-8-*- import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # initial model model = TheModelClass() #initialize the optimizer optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # print the model's state_dict print("model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,'\t',model.state_dict()[param_tensor].size()) print("\noptimizer's state_dict") for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("\nprint particular param") print('\n',model.conv1.weight.size()) print('\n',model.conv1.weight) print("------------------------------------") torch.save(model.state_dict(),'./model_state_dict.pt') # model_2 = TheModelClass() # model_2.load_state_dict(torch.load('./model_state_dict')) # model.eval() # print('\n',model_2.conv1.weight) # print((model_2.conv1.weight == model.conv1.weight).size()) ## 僅僅加載某一層的參數(shù) conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight'] print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass() model_2.load_state_dict(torch.load('./model_state_dict.pt')) model_2.conv1.requires_grad=False print(model_2.conv1.requires_grad) print(model_2.conv1.bias.requires_grad)
以上這篇pytorch 狀態(tài)字典:state_dict使用詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法示例【基于反向傳播算法】
這篇文章主要介紹了Python實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法,結(jié)合實(shí)例形式分析了Python基于反向傳播算法實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)相關(guān)操作技巧,需要的朋友可以參考下2017-11-11Python利用treap實(shí)現(xiàn)雙索引的方法
所遍歷的元素一定是遞增(小堆)或是遞減(大堆)關(guān)系,但是我們無法得知左子樹與右子樹兩部分節(jié)點(diǎn)的排序關(guān)系。本文就來講講算法和數(shù)據(jù)結(jié)構(gòu)共同滿足一組特性,感興趣的小伙伴請參考下面文章的內(nèi)容2021-09-09python使用redis模塊來跟redis實(shí)現(xiàn)交互
這篇文章主要介紹了python使用redis模塊來跟redis實(shí)現(xiàn)交互,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-06-06python判斷、獲取一張圖片主色調(diào)的2個(gè)實(shí)例
一幅圖片,想通過程序判斷獲得其主要色調(diào),應(yīng)該怎么樣處理?本文通過python實(shí)現(xiàn)判斷、獲取一張圖片的主色調(diào)方法,需要的朋友可以參考下2014-04-04詳解Python?Flask?API?示例演示(附cookies和session)
這篇文章主要為大家介紹了Python?Flask?API?示例演示(附cookies和session)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-03-03利用Python將時(shí)間或時(shí)間間隔轉(zhuǎn)為ISO 8601格式方法示例
國際標(biāo)準(zhǔn)化組織的國際標(biāo)準(zhǔn)ISO8601是日期和時(shí)間的表示方法,全稱為《數(shù)據(jù)存儲和交換形式·信息交換·日期和時(shí)間的表示方法》,下面這篇文章主要給大家介紹了關(guān)于利用Python將時(shí)間或時(shí)間間隔轉(zhuǎn)為ISO 8601格式的相關(guān)資料,需要的朋友可以參考下。2017-09-09python計(jì)算階乘和的方法(1!+2!+3!+...+n!)
今天小編就為大家分享一篇python計(jì)算階乘和的方法(1!+2!+3!+...+n!),具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-02-02Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值
這篇文章主要介紹了Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10