亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

怎樣保存模型權(quán)重和checkpoint

 更新時(shí)間:2022年12月17日 11:16:12   作者:取個(gè)名字真難吶  
這篇文章主要介紹了如何保存模型權(quán)重和checkpoint,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

概述

在pytorch中有兩種方式可以保存推理模型,第一種是只保存模型的參數(shù),比如parameters和buffers;另外一種是保存整個(gè)模型;

1.保存模型 - 權(quán)重參數(shù)

我們可以用torch.save()函數(shù)來保存model.state_dict();state_dict()里面包含模型的parameters&buffers;這種方法只保存模型中必要的訓(xùn)練參數(shù)。

你可以用pytorch中的pickle來保存模型;使用這種方法可以生成最直觀的語法,并涉及最少的代碼;這種方法的缺點(diǎn)是,序列化的數(shù)據(jù)被綁定到特定的類和保存模型時(shí)使用的確切的目錄結(jié)構(gòu)。

這樣做的原因是pickle并不保存模型類本身。相反,它保存包含類的文件的路徑,在加載期間使用;因此,當(dāng)在其他項(xiàng)目中使用或重構(gòu)后,您的代碼可能以各種方式中斷。

我們將探討如何保存和加載模型進(jìn)行推斷的兩種方法。

步驟:

(1)導(dǎo)入所有必要的庫(kù)來加載我們的數(shù)據(jù)

(2)定義和初始化神經(jīng)網(wǎng)絡(luò)

(3)初始化優(yōu)化器

(4)保存并通過state_dict加載模型

(5)保存并加載整個(gè)模型

1.1代碼

# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: Neural_Network_test
# @Create time: 2022/3/19 15:33

# 1.導(dǎo)入相關(guān)數(shù)據(jù)庫(kù)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F


# 2.定義神經(jīng)網(wǎng)絡(luò)模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 3. 實(shí)例化神經(jīng)網(wǎng)絡(luò)
net = Net()

# 4. 實(shí)例化優(yōu)化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 5. 保存模型參數(shù)
# Specify a path
PATH = "state_dict_model.pt"

# 6. 保存模型的參數(shù)字典:parameters and buffers
torch.save(net.state_dict(), PATH)

# 7. 實(shí)例化新的模型
model = Net()

# 8. 給新的實(shí)例加載之前的模型參數(shù)
model.load_state_dict(torch.load(PATH))

# 9. 設(shè)置模型為評(píng)估模式
model.eval()

注意(1):

pytorch中常用的慣例是將model.state_dict()保存為"state_dict_model.pt",即文件的格式一般是.pt或者.pth格式文件;注意load_state_dict加載的是一個(gè)字典,而不是路徑。

注意(2):

模型參數(shù)在推理階段一定要設(shè)置model.eval();這樣可以讓dropout和batchnorm失效,如果沒設(shè)置推理模式,會(huì)得到不一樣的結(jié)果。

2.保存模型 - 整個(gè)模型

將模型所有的內(nèi)容都保存下來。 

# Specify a path
PATH = "entire_model.pt"

# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()

3.保存模型 - checkpoints

我們按照checkpoints模式來保存模型,本質(zhì)上就是按照字典的模式進(jìn)行分門別類的保存,我們可以通過鍵值進(jìn)行加載。

  • epoch:訓(xùn)練周期
  • model_state_dict:模型可訓(xùn)練參數(shù)
  • optimizer_state_dict:模型優(yōu)化器參數(shù)
  • loss:模型的損失函數(shù)
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

保存和加載通用的檢查點(diǎn)模型以進(jìn)行推斷或恢復(fù)訓(xùn)練,這有助于您從上一個(gè)地方繼續(xù)進(jìn)行。

當(dāng)保存一個(gè)常規(guī)檢查點(diǎn)時(shí),您必須保存模型的state_dict之外的更多信息。

保存優(yōu)化器的state_dict也很重要,因?yàn)樗彌_區(qū)和參數(shù),隨著模型的運(yùn)行而更新。

您可能希望保存的其他項(xiàng)目是您離開的時(shí)期,最新記錄的訓(xùn)練損失,外部torch.nn.嵌入層,以及更多,基于自己的算法

3.1代碼

# 1.導(dǎo)入相關(guān)數(shù)據(jù)庫(kù)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

# 2. 定義神經(jīng)網(wǎng)絡(luò)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 3. 實(shí)例化神經(jīng)網(wǎng)絡(luò)
net = Net()

# 4. 實(shí)例化優(yōu)化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Additional information

# 5. 定義超參數(shù)
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

# 6. 以checkpoints形式保存模型的相關(guān)數(shù)據(jù)
torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

# 7. 重新實(shí)例化一個(gè)模型
model = Net()

# 8. 實(shí)例化優(yōu)化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


# 9. 加載以前的checkpoint
checkpoint = torch.load(PATH)

# 10. 通過鍵值來加載相關(guān)參數(shù)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 11.設(shè)置推理模式
model.eval()
# - or -
model.train()

4.保存雙模型

當(dāng)保存有多個(gè)神經(jīng)網(wǎng)絡(luò)模型組成的神經(jīng)網(wǎng)絡(luò)時(shí),比如GAN對(duì)抗模型,sequence-to-sequence序列到序列模型,或者一個(gè)組合模型,你必須為每一個(gè)模型保存狀態(tài)字典state_dict()和其對(duì)應(yīng)的優(yōu)化器參數(shù)optimizer.state_dict();您還可以保存任何其他項(xiàng)目,可能會(huì)幫助您恢復(fù)訓(xùn)練,只需將它們添加到字典;為了加載模型,第一步是初始化神經(jīng)網(wǎng)絡(luò)模型和優(yōu)化器,然后用torch.load()去加載checkpoint對(duì)應(yīng)的數(shù)據(jù),因?yàn)閏heckpoints是字典,所以我們可以通過鍵值進(jìn)行查詢導(dǎo)入;

4.1相關(guān)步驟

(1)導(dǎo)入所有相關(guān)的數(shù)據(jù)庫(kù)

(2)定義和實(shí)例化神經(jīng)網(wǎng)絡(luò)模型

(3)初始化優(yōu)化器

(4)保存多重模型

(5)加載多重模型

# 1.導(dǎo)入相關(guān)數(shù)據(jù)庫(kù)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

# 2. 定義神經(jīng)網(wǎng)絡(luò)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 3. 實(shí)例化神經(jīng)網(wǎng)絡(luò)A,B
netA = Net()
netB = Net()

# 4. 實(shí)例化優(yōu)化器A,B
optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)

# 5. 保存模型
# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

# 6.重新實(shí)例化新的網(wǎng)絡(luò)模型A,B
modelA = Net()
modelB = Net()

# 7. 重新實(shí)例化新的網(wǎng)絡(luò)模型A,B
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

# 8. 將以前模型的參數(shù)重新加載到新的模型A,B中
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

# 9. 開啟預(yù)測(cè)模式
modelA.eval()
modelB.eval()
# - or -
# 10.開啟訓(xùn)練模式
modelA.train()
modelB.train()

5.機(jī)器學(xué)習(xí)流程圖

6.機(jī)器學(xué)習(xí)常用庫(kù)

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python解析json代碼實(shí)例解析

    Python解析json代碼實(shí)例解析

    這篇文章主要介紹了Python解析json代碼實(shí)例解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-11-11
  • 詳解python 爬取12306驗(yàn)證碼

    詳解python 爬取12306驗(yàn)證碼

    這篇文章主要介紹了python爬取12306驗(yàn)證碼,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-05-05
  • 詳解Python使用tensorflow入門指南

    詳解Python使用tensorflow入門指南

    本篇文章主要介紹了詳解Python使用tensorflow入門指南,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2018-02-02
  • python計(jì)算方程式根的方法

    python計(jì)算方程式根的方法

    這篇文章主要介紹了python計(jì)算方程式根的方法,涉及Python數(shù)學(xué)運(yùn)算的相關(guān)技巧,需要的朋友可以參考下
    2015-05-05
  • Python實(shí)現(xiàn)爬取某站視頻彈幕并繪制詞云圖

    Python實(shí)現(xiàn)爬取某站視頻彈幕并繪制詞云圖

    這篇文章主要介紹了利用Python爬取某站的視頻彈幕,并將其繪制成詞云圖,文中的示例代碼講解詳細(xì),對(duì)我學(xué)習(xí)Python爬蟲有一定的幫助,需要的朋友可以參考一下
    2021-12-12
  • Python中index()函數(shù)與find()函數(shù)的區(qū)別詳解

    Python中index()函數(shù)與find()函數(shù)的區(qū)別詳解

    這篇文章主要介紹了Python中index()函數(shù)與find()函數(shù)的區(qū)別詳解,Python index()方法檢測(cè)字符串中是否包含子字符串 str ,如果指定beg開始和end結(jié)束范圍,則檢查是否包含在指定范圍內(nèi),需要的朋友可以參考下
    2023-08-08
  • LyScript實(shí)現(xiàn)內(nèi)存交換與差異對(duì)比的方法詳解

    LyScript實(shí)現(xiàn)內(nèi)存交換與差異對(duì)比的方法詳解

    LyScript?針對(duì)內(nèi)存讀寫函數(shù)的封裝功能并不多,只提供了內(nèi)存讀取和內(nèi)存寫入函數(shù)的封裝,本篇文章將繼續(xù)對(duì)API進(jìn)行封裝,實(shí)現(xiàn)一些在軟件逆向分析中非常實(shí)用的功能,需要的可以參考一下
    2022-08-08
  • 解決已經(jīng)安裝requests,卻依然提示No module named requests問題

    解決已經(jīng)安裝requests,卻依然提示No module named requests問題

    今天小編就為大家分享一篇解決已經(jīng)安裝requests,卻依然提示No module named 'requests'問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2018-05-05
  • pytorch之Resize()函數(shù)具體使用詳解

    pytorch之Resize()函數(shù)具體使用詳解

    這篇文章主要介紹了pytorch之Resize()函數(shù)具體使用詳解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-02-02
  • Opencv圖像添加椒鹽噪聲、高斯濾波去除噪聲原理以及手寫Python代碼實(shí)現(xiàn)方法

    Opencv圖像添加椒鹽噪聲、高斯濾波去除噪聲原理以及手寫Python代碼實(shí)現(xiàn)方法

    椒鹽噪聲的特征非常明顯,為圖像上有黑色和白色的點(diǎn),下面這篇文章主要給大家介紹了關(guān)于Opencv圖像添加椒鹽噪聲、高斯濾波去除噪聲原理以及手寫Python代碼實(shí)現(xiàn)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下
    2022-09-09

最新評(píng)論