PyTorch模型的保存與加載方法實(shí)例
模型的保存與加載
首先,需要導(dǎo)入兩個(gè)包
import torch import torchvision.models as models
保存和加載模型參數(shù)
PyTorch
模型將學(xué)習(xí)到的參數(shù)存儲(chǔ)在一個(gè)內(nèi)部狀態(tài)字典中,叫做state_dict
。這可以通過torch.save
方法來實(shí)現(xiàn)。
我們導(dǎo)入預(yù)訓(xùn)練好的VGG16
模型,并將其保存。我們將state_dict
字典保存在model_weights.pth
文件中。
model = models.vgg16(pretrained=True) torch.save(model.state_dict(), 'model_weights.pth')
想要加載模型參數(shù),我們需要?jiǎng)?chuàng)建一個(gè)和原模型一樣的實(shí)例,然后通過load_state_dict()
方法來加載模型參數(shù)
- 創(chuàng)建一個(gè)
VGG16
模型實(shí)例(未經(jīng)過預(yù)訓(xùn)練的) - 加載本地參數(shù)
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights model.load_state_dict(torch.load('model_weights.pth')) model.eval()
注意:在進(jìn)行測試前,如果模型中有dropout
層和batch normalization
層的話,一定要使用model.eval()
將模型轉(zhuǎn)到測試模式。
- 在
train
模式下,dropout
網(wǎng)絡(luò)層會(huì)按照設(shè)定的參數(shù)p
設(shè)置保留激活單元的概率(保留概率=p
);batchnorm
層會(huì)繼續(xù)計(jì)算數(shù)據(jù)的mean
和var
等參數(shù)并更新。 - 在
val
模式下,dropout
層會(huì)讓所有的激活單元都通過,而batchnorm
層會(huì)停止計(jì)算和更新mean
和var
,直接使用在訓(xùn)練階段已經(jīng)學(xué)出的mean
和var
值
當(dāng)然,相同的,在模型進(jìn)行訓(xùn)練之前,要使用model.train()
來將模型轉(zhuǎn)為訓(xùn)練模式
保存和加載模型參數(shù)與結(jié)構(gòu)
當(dāng)加載模型權(quán)重時(shí),我們需要首先實(shí)例化模型類,因?yàn)轭惗x了網(wǎng)絡(luò)的結(jié)構(gòu)。我們可能希望將這個(gè)類的結(jié)構(gòu)與模型保存在一起。這樣的話,我們可以將model
而不是model.state_dict()
作為參數(shù)。
torch.save(model, 'model.pth')
這樣,我們加載模型的時(shí)候就不用再新建一個(gè)實(shí)例了。加載方式如下所示
model = torch.load('model.pth')
這種方式在網(wǎng)絡(luò)比較大的時(shí)候可能比較慢,因?yàn)橄噍^于上面的方式多存儲(chǔ)了網(wǎng)絡(luò)的結(jié)構(gòu)
總結(jié)
到此這篇關(guān)于PyTorch模型的保存與加載方法的文章就介紹到這了,更多相關(guān)PyTorch模型保存加載內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
合并Excel工作薄中成績表的VBA代碼,非常適合教育一線的朋友
每次學(xué)生考試,評分完畢之后,把每個(gè)科的成績收集起來,就得到了一個(gè)有若干工作表,每個(gè)表有學(xué)生學(xué)號、分?jǐn)?shù)等列的Excel工作薄。2009-04-04python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲的示例代碼(可以自定義子彈速度)
這篇文章主要介紹了python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲---可以自定義子彈速度,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-08-08python輸入一個(gè)水仙花數(shù)(三位數(shù)) 輸出百位十位個(gè)位實(shí)例
這篇文章主要介紹了python輸入一個(gè)水仙花數(shù)(三位數(shù)) 輸出百位十位個(gè)位實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05Python wxPython庫Core組件BoxSizer用法示例
這篇文章主要介紹了Python wxPython庫Core組件BoxSizer用法,結(jié)合實(shí)例形式分析了wxPython BoxSizer布局管理相關(guān)使用方法及操作注意事項(xiàng),需要的朋友可以參考下2018-09-09python中str內(nèi)置函數(shù)用法總結(jié)
在本篇文章里小編給大家整理了一篇關(guān)于python中str內(nèi)置函數(shù)用法總結(jié)內(nèi)容,有需要的朋友們可以學(xué)習(xí)下。2020-12-12