pytorch GPU和CPU模型相互加載方式
1 pytorch保存模型的兩種方式
1.1 直接保存模型并讀取
# 創(chuàng)建你的模型實(shí)例對(duì)象: model model = net() ## 保存模型 torch.save(model, 'model_name.pth') ## 讀取模型 model = torch.load('model_name.pth')
1.2 只保存模型中的參數(shù)并讀取
## 保存模型 torch.save({'model': model.state_dict()}, 'model_name.pth') ## 讀取模型 model = net() state_dict = torch.load('model_name.pth') model.load_state_dict(state_dict['model'])
- 第一種方法可以直接保存模型,加載模型的時(shí)候直接把讀取的模型給一個(gè)參數(shù)就行。
- 第二種方法則只是保存參數(shù),在讀取模型參數(shù)前要先定義一個(gè)模型(模型必須與原模型相同的構(gòu)造),然后對(duì)這個(gè)模型導(dǎo)入?yún)?shù)。雖然麻煩,但是可以同時(shí)保存多個(gè)模型的參數(shù),而第一種方法則不能,而且第一種方法有時(shí)不能保證模型的相同性(你讀取的模型并不是你想要的)。
如何保存模型決定了如何讀取模型,一般來(lái)選擇第二種來(lái)保存和讀取。
2 GPU / CPU模型相互加載
2.1 單個(gè)CPU和單個(gè)GPU模型加載
pytorch 允許把在GPU上訓(xùn)練的模型加載到CPU上,也允許把在CPU上訓(xùn)練的模型加載到GPU上。
加載模型參數(shù)的時(shí)候,在GPU和CPU訓(xùn)練的模型是不一樣的,這兩種模型是不能混為一談的,下面分情況進(jìn)行操作說(shuō)明。
情況一:CPU -> CPU, GPU -> GPU
- GPU訓(xùn)練的模型,在GPU上使用;
- CPU訓(xùn)練的模型,在CPU上使用,
這種情況下我們都只用直接用下面的語(yǔ)句即可:
torch.load('model_dict.pth')
情況二:GPU -> CPG/GPU
GPU訓(xùn)練的模型,不知道放在CPU還是GPU運(yùn)行,兩種情況都要考慮
import torch from torchvision import models # 加載預(yù)訓(xùn)練的GPU模型權(quán)重文件 weights_path = 'model_gpu.pth' # 定義一個(gè)與原模型結(jié)構(gòu)相同的新模型 model = models.resnet50() # 檢查是否有可用的CUDA設(shè)備 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 將權(quán)重映射到相應(yīng)的設(shè)備內(nèi)存并加載到模型中 weights = torch.load(weights_path, map_location=device) model.load_state_dict(weights) # 設(shè)置為評(píng)估模式 model.eval() print("Model is successfully loaded and can be used on a", device.type, "!")
情況三:CPU -> CPG/GPU
模型是在CPU上訓(xùn)練的,但不確定要在CPU還是GPU上運(yùn)行時(shí),兩種情況都要考慮
import torch from torchvision import models # 加載預(yù)訓(xùn)練的CPU模型權(quán)重文件 weights_path = 'model_cpu.pth' # 定義一個(gè)與原模型結(jié)構(gòu)相同的新模型 model = models.resnet50() # 檢查是否有可用的CUDA設(shè)備 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 將權(quán)重映射到相應(yīng)的設(shè)備內(nèi)存并加載到模型中 if device.type == 'cuda': model.to(device) weights = torch.load(weights_path, map_location=device) else: weights = torch.load(weights_path, map_location='cpu') model.load_state_dict(weights) # 設(shè)置為評(píng)估模式 model.eval() print("Model is successfully loaded and can be used on a", device.type, "!")
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
詳解用python -m http.server搭一個(gè)簡(jiǎn)易的本地局域網(wǎng)
這篇文章主要介紹了詳解用python -m http.server搭一個(gè)簡(jiǎn)易的本地局域網(wǎng),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09用selenium解決滑塊驗(yàn)證碼的實(shí)現(xiàn)步驟
驗(yàn)證碼作為一種自然人的機(jī)器人的判別工具,被廣泛的用于各種防止程序做自動(dòng)化的場(chǎng)景中,下面這篇文章主要給大家介紹了關(guān)于用selenium解決滑塊驗(yàn)證碼的實(shí)現(xiàn)步驟,需要的朋友可以參考下2023-02-02Python利用卡方Chi特征檢驗(yàn)實(shí)現(xiàn)提取關(guān)鍵文本特征
卡方檢驗(yàn)最基本的思想就是通過觀察實(shí)際值與理論值的偏差來(lái)確定理論的正確與否。本文將利用卡方Chi特征檢驗(yàn)實(shí)現(xiàn)提取關(guān)鍵文本特征功能,感興趣的可以了解一下2022-12-12Pygame實(shí)現(xiàn)簡(jiǎn)易版趣味小游戲之反彈球
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)簡(jiǎn)易版趣味反彈球游戲,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-03-03python linecache讀取行更新的實(shí)現(xiàn)
本文主要介紹了python linecache讀取行更新的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03