PyTorch深度學習模型的保存和加載流程詳解
更新時間:2021年10月21日 09:32:00 作者:軟耳朵DONG
PyTorch是一個開源的Python機器學習庫,基于Torch,用于自然語言處理等應用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch,這篇文章主要介紹了PyTorch模型的保存和加載流程
一、模型參數的保存和加載
-
torch.save(module.state_dict(), path):使用module.state_dict()函數獲取各層已經訓練好的參數和緩沖區(qū),然后將參數和緩沖區(qū)保存到path所指定的文件存放路徑(常用文件格式為.pt、.pth或.pkl)。 torch.nn.Module.load_state_dict(state_dict):從state_dict中加載參數和緩沖區(qū)到Module及其子類中 。torch.nn.Module.state_dict()函數返回python中的一個OrderedDict類型字典對象,該對象將每一層與它的對應參數和緩沖區(qū)建立映射關系,字典的鍵值是參數或緩沖區(qū)的名稱。只有那些參數可以訓練的層才會被保存到OrderedDict中,例如:卷積層、線性層等。Python中的字典類以“鍵:值”方式存取數據,OrderedDict是它的一個子類,實現了對字典對象中元素的排序(OrderedDict根據放入元素的先后順序進行排序)。由于進行了排序,所以順序不同的兩個OrderedDict字典對象會被當做是兩個不同的對象。- 示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化網絡
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 獲取state_dict
state_dict = net.state_dict()
# 字典的遍歷默認是遍歷key,所以param_tensor實際上是鍵值
for param_tensor in state_dict:
print(param_tensor,':\n',state_dict[param_tensor])
# 保存模型參數
torch.save(state_dict,"net_params.pth")
# 通過加載state_dict獲取模型參數
net.load_state_dict(state_dict)
輸出:

二、完整模型的保存和加載
-
torch.save(module, path):將訓練完的整個網絡模型module保存到path所指定的文件存放路徑(常用文件格式為.pt或.pth)。 torch.load(path):加載保存到path中的整個神經網絡模型。- 示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化網絡
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整個網絡
torch.save(net,"net.pth")
# 加載網絡
net = torch.load("net.pth")
到此這篇關于PyTorch深度學習模型的保存和加載流程詳解的文章就介紹到這了,更多相關PyTorch 模型的保存 內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python 格式化輸出_String Formatting_控制小數點位數的實例詳解
在本篇文章里小編給大家整理了關于Python 格式化輸出_String Formatting_控制小數點位數的實例內容,需要的朋友們參考下。2020-02-02
VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法詳解
這篇文章主要介紹了VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法,較為詳細的分析了VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的具體步驟、相關命令與操作注意事項,需要的朋友可以參考下2019-07-07
國產化設備鯤鵬CentOS7上源碼安裝Python3.7的過程詳解
這篇文章主要介紹了國產化設備鯤鵬CentOS7上源碼安裝Python3.7,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2022-05-05

