PyTorch快速搭建神經網絡及其保存提取方法詳解
有時候我們訓練了一個模型, 希望保存它下次直接使用,不需要下次再花時間去訓練 ,本節(jié)我們來講解一下PyTorch快速搭建神經網絡及其保存提取方法詳解
一、PyTorch快速搭建神經網絡方法
先看實驗代碼:
import torch
import torch.nn.functional as F
# 方法1,通過定義一個Net類來建立神經網絡
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x
net1 = Net(2, 10, 2)
print('方法1:\n', net1)
# 方法2 通過torch.nn.Sequential快速建立神經網絡結構
net2 = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2),
)
print('方法2:\n', net2)
# 經驗證,兩種方法構建的神經網絡功能相同,結構細節(jié)稍有不同
'''''
方法1:
Net (
(hidden): Linear (2 -> 10)
(predict): Linear (10 -> 2)
)
方法2:
Sequential (
(0): Linear (2 -> 10)
(1): ReLU ()
(2): Linear (10 -> 2)
)
'''
先前學習了通過定義一個Net類來構建神經網絡的方法,classNet中首先通過super函數(shù)繼承torch.nn.Module模塊的構造方法,再通過添加屬性的方式搭建神經網絡各層的結構信息,在forward方法中完善神經網絡各層之間的連接信息,然后再通過定義Net類對象的方式完成對神經網絡結構的構建。
構建神經網絡的另一個方法,也可以說是快速構建方法,就是通過torch.nn.Sequential,直接完成對神經網絡的建立。
兩種方法構建得到的神經網絡結構完全相同,都可以通過print函數(shù)來打印輸出網絡信息,不過打印結果會有些許不同。
二、PyTorch的神經網絡保存和提取
在學習和研究深度學習的時候,當我們通過一定時間的訓練,得到了一個比較好的模型的時候,我們當然希望將這個模型及模型參數(shù)保存下來,以備后用,所以神經網絡的保存和模型參數(shù)提取重載是很有必要的。
首先,我們需要在需要保存網路結構及其模型參數(shù)的神經網絡的定義、訓練部分之后通過torch.save()實現(xiàn)對網絡結構和模型參數(shù)的保存。有兩種保存方式:一是保存年整個神經網絡的的結構信息和模型參數(shù)信息,save的對象是網絡net;二是只保存神經網絡的訓練模型參數(shù),save的對象是net.state_dict(),保存結果都以.pkl文件形式存儲。
對應上面兩種保存方式,重載方式也有兩種。對應第一種完整網絡結構信息,重載的時候通過torch.load(‘.pkl')直接初始化新的神經網絡對象即可。對應第二種只保存模型參數(shù)信息,需要首先搭建相同的神經網絡結構,通過net.load_state_dict(torch.load('.pkl'))完成模型參數(shù)的重載。在網絡比較大的時候,第一種方法會花費較多的時間。
代碼實現(xiàn):
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
torch.manual_seed(1) # 設定隨機數(shù)種子
# 創(chuàng)建數(shù)據
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
# 將待保存的神經網絡定義在一個函數(shù)中
def save():
# 神經網絡結構
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_function = torch.nn.MSELoss()
# 訓練部分
for i in range(300):
prediction = net1(x)
loss = loss_function(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 繪圖部分
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# 保存神經網絡
torch.save(net1, '7-net.pkl') # 保存整個神經網絡的結構和模型參數(shù)
torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神經網絡的模型參數(shù)
# 載入整個神經網絡的結構及其模型參數(shù)
def reload_net():
net2 = torch.load('7-net.pkl')
prediction = net2(x)
plt.subplot(132)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# 只載入神經網絡的模型參數(shù),神經網絡的結構需要與保存的神經網絡相同的結構
def reload_params():
# 首先搭建相同的神經網絡結構
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
)
# 載入神經網絡的模型參數(shù)
net3.load_state_dict(torch.load('7-net_params.pkl'))
prediction = net3(x)
plt.subplot(133)
plt.title('net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# 運行測試
save()
reload_net()
reload_params()
實驗結果:

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持腳本之家。
相關文章
Python + Requests + Unittest接口自動化測試實例分析
這篇文章主要介紹了Python + Requests + Unittest接口自動化測試,結合具體實例形式分析了Python使用Requests與Unittest模塊實現(xiàn)接口自動化測試相關操作技巧,需要的朋友可以參考下2019-12-12
解決jupyter notebook圖片顯示模糊和保存清晰圖片的操作
這篇文章主要介紹了解決jupyter notebook圖片顯示模糊和保存清晰圖片的操作方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-04-04
Linux上安裝Python的PIL和Pillow庫處理圖片的實例教程
這里我們來看一下在Linux上安裝Python的PIL和Pillow庫處理圖片的實例教程,包括一個使用Pillow庫實現(xiàn)批量轉換圖片的例子:2016-06-06

