pytorch模型保存方式
pytorch模型保存
保存模型主要分為兩類:
- 保存整個(gè)模型
- 只保存模型參數(shù)
1.保存加載整個(gè)模型(不推薦)
保存整個(gè)網(wǎng)絡(luò)模型,網(wǎng)絡(luò)結(jié)構(gòu)+權(quán)重參數(shù)
torch.save(model,'net.pth')
加載整個(gè)網(wǎng)絡(luò)模型(可能比較耗時(shí))
model=torch.load('net.pth')
2.只保存加載模型參數(shù)(推薦)
保存模型的權(quán)重參數(shù)(速度快,占內(nèi)存少)
torch.save(model.state_dict(),'net_params.pth')
load 模型參數(shù)
因?yàn)槲覀冎槐4媪?模型的參數(shù),所以需要先定義一個(gè)網(wǎng)絡(luò)對(duì)象,然后再加載模型參數(shù)。
model=myNet()
#將模型參數(shù)加載到新模型中,torch.load返回的是一個(gè)OrderedDict,說(shuō)明.state_dict()只是把所有模型的參數(shù)都已OrderedDict的形式存下來(lái)。
state_dict=torch.load('net_params.pth') model.load_state_dict(state_dict)
Note:保存模型進(jìn)行推理測(cè)試時(shí),只需保存訓(xùn)練好的模型的權(quán)重參數(shù),即推薦第二種方法。
load_state_dict的參數(shù)strict=False new_model.load_state_dict(state_dict,strict=False)
如果哪一天我們需要重新寫這個(gè)網(wǎng)絡(luò)的,比如使用new_model,如果直接load會(huì)出現(xiàn)unexpected key.
但是加上strict=False可以很容易地加載預(yù)訓(xùn)練的參數(shù)(注意檢查key是否匹配),直接忽略不匹配的key,對(duì)于匹配的key則進(jìn)行正常的賦值。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python3實(shí)現(xiàn)跳一跳點(diǎn)擊跳躍
這篇文章主要為大家詳細(xì)介紹了python3實(shí)現(xiàn)跳一跳點(diǎn)擊跳躍,玩跳一跳小游戲的思路,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-01-01python?dataframe獲得指定行列實(shí)戰(zhàn)代碼
對(duì)于一個(gè)DataFrame,常常需要篩選出某列為指定值的行,下面這篇文章主要給大家介紹了關(guān)于python?dataframe獲得指定行列的相關(guān)資料,文中通過(guò)代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-12-12對(duì)python判斷ip是否可達(dá)的實(shí)例詳解
今天小編就為大家分享一篇對(duì)python判斷ip是否可達(dá)的實(shí)例詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01利用Python實(shí)現(xiàn)自動(dòng)化監(jiān)控文件夾完成服務(wù)部署
本篇文章將為大家詳細(xì)介紹如何利用Python語(yǔ)言實(shí)現(xiàn)監(jiān)控文件夾,以此輔助完成服務(wù)的部署動(dòng)作,文中的示例代碼講解詳細(xì),感興趣的可以嘗試一下2022-07-07django頁(yè)面跳轉(zhuǎn)問(wèn)題及注意事項(xiàng)
這篇文章主要介紹了django頁(yè)面跳轉(zhuǎn)問(wèn)題及注意事項(xiàng),本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-07-07