pytorch模型保存方式
pytorch模型保存
保存模型主要分為兩類:
- 保存整個模型
- 只保存模型參數(shù)
1.保存加載整個模型(不推薦)
保存整個網絡模型,網絡結構+權重參數(shù)
torch.save(model,'net.pth')
加載整個網絡模型(可能比較耗時)
model=torch.load('net.pth')2.只保存加載模型參數(shù)(推薦)
保存模型的權重參數(shù)(速度快,占內存少)
torch.save(model.state_dict(),'net_params.pth')
load 模型參數(shù)
因為我們只保存了 模型的參數(shù),所以需要先定義一個網絡對象,然后再加載模型參數(shù)。
model=myNet()
#將模型參數(shù)加載到新模型中,torch.load返回的是一個OrderedDict,說明.state_dict()只是把所有模型的參數(shù)都已OrderedDict的形式存下來。
state_dict=torch.load('net_params.pth')
model.load_state_dict(state_dict)Note:保存模型進行推理測試時,只需保存訓練好的模型的權重參數(shù),即推薦第二種方法。
load_state_dict的參數(shù)strict=False new_model.load_state_dict(state_dict,strict=False)
如果哪一天我們需要重新寫這個網絡的,比如使用new_model,如果直接load會出現(xiàn)unexpected key.
但是加上strict=False可以很容易地加載預訓練的參數(shù)(注意檢查key是否匹配),直接忽略不匹配的key,對于匹配的key則進行正常的賦值。
總結
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python?dataframe獲得指定行列實戰(zhàn)代碼
對于一個DataFrame,常常需要篩選出某列為指定值的行,下面這篇文章主要給大家介紹了關于python?dataframe獲得指定行列的相關資料,文中通過代碼介紹的非常詳細,需要的朋友可以參考下2023-12-12
利用Python實現(xiàn)自動化監(jiān)控文件夾完成服務部署
本篇文章將為大家詳細介紹如何利用Python語言實現(xiàn)監(jiān)控文件夾,以此輔助完成服務的部署動作,文中的示例代碼講解詳細,感興趣的可以嘗試一下2022-07-07

