Pytorch中實(shí)現(xiàn)CPU和GPU之間的切換的兩種方法
如何在pytorch中指定CPU和GPU進(jìn)行訓(xùn)練,以及cpu和gpu之間切換
由CPU切換到GPU,要修改的幾個地方:
網(wǎng)絡(luò)模型、損失函數(shù)、數(shù)據(jù)(輸入,標(biāo)注)
# 創(chuàng)建網(wǎng)絡(luò)模型 tudui = Tudui() if torch.cuda.is_available(): tudui = tudui.cuda() # 損失函數(shù) loss_fn = nn.CrossEntropyLoss() if torch.cuda.is_available(): loss_fn = loss_fn.cuda() # 數(shù)據(jù)輸入 包括訓(xùn)練和測試的代碼,二者都需要添加此代碼 if torch.cuda.is_available(): imgs = imgs.cuda() targets = targets.cuda()
方法一:.to(device)
1.不知道電腦GPU可不可用時:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' ) a.to(device)
第一行代碼的意思是判斷電腦GPU可不可用,如果可用的話device就采用cuda()即調(diào)用GPU,不可用的話就采用cpu()即調(diào)用CPU。
第二行代碼的意思就是把變量放到對應(yīng)的device上(當(dāng)然如果你用的是CPU的話就不用這一步了,因?yàn)樽兞磕J(rèn)是存在CPU上的,調(diào)用GPU的話要先把變量放到GPU上跑,跑完之后再調(diào)回CPU上)
2.指定GPU時
# 定義訓(xùn)練的設(shè)備 device = torch.device("cuda:0") # 網(wǎng)絡(luò)模型創(chuàng)建 tudui = Tudui() tudui = tudui.to(device) # 損失函數(shù) loss_fn = nn.CrossEntropyLoss() loss_fn = loss_fn.to(device) # 訓(xùn)練步驟開始 tudui.train() for data in train_dataloader: imgs, targets=data imgs = imgs.to(device) targets = targets.to(device) outputs = tudui(imgs) loss = loss_fn(outputs, targets) # 測試步驟開始 tudui.eval() total_test_loss = 0 total_accuracy = 0 with torch.no_grad(): for data in test_dataloader: imgs, targets=data imgs = imgs.to(device) targets = targets.to(device) outputs = tudui(imgs) loss = loss_fn(outputs, targets) total_test_loss = total_test_loss + loss.item() accuracy = (outputs.argmax(1)==targets).sum() total_accuracy = total_accuracy + accuracy
3.指定cpu時:
device = torch.device('cpu')
方法二:
1、需要修改的
# 三種常見的寫法 device = torch.device('cuda') device = torch.device('cuda: 0') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2、代碼
# 創(chuàng)建模型 tudui = Tudui() if torch.cuda.is_available(): tudui = tudui.cuda() # 損失函數(shù) loss_fn = nn.CrossEntropyLoss() if torch.cuda.is_available(): loss_fn = loss_fn.cuda() # 訓(xùn)練步驟開始 tudui.train() for data in train_dataloader: imgs, targets=data if torch.cuda.is_available(): imgs = imgs.cuda() targets = targets.cuda() outputs = tudui(imgs) loss = loss_fn(outputs, targets) # 測試步驟開始 tudui.eval() total_test_loss = 0 total_accuracy = 0 with torch.no_grad(): for data in test_dataloader: imgs, targets=data if torch.cuda.is_available(): imgs = imgs.cuda() targets = targets.cuda() outputs = tudui(imgs) loss = loss_fn(outputs, targets) total_test_loss = total_test_loss + loss.item() accuracy = (outputs.argmax(1)==targets).sum() total_accuracy = total_accuracy + accuracy
總結(jié):
推薦方法一,如果自己電腦是只有CPU,可以推薦使用云端服務(wù)器,比如PaddlePaddle,Google colab,這些服務(wù)器由每周免費(fèi)八個小時的使用時間,可供我們基本的需求。
到此這篇關(guān)于Pytorch中實(shí)現(xiàn)CPU和GPU之間的切換的兩種方法的文章就介紹到這了,更多相關(guān)Pytorch CPU和GPU切換內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Pycharm搭建Django項(xiàng)目詳細(xì)教程(看完這一篇就夠了)
這篇文章主要給大家介紹了關(guān)于Pycharm搭建Django項(xiàng)目的詳細(xì)教程,想要學(xué)習(xí)的小伙伴看完這一篇就夠了,pycharm是一種Python?IDE,帶有一整套可以幫助用戶在使用Python語言開發(fā)時提高其效率的工具,需要的朋友可以參考下2023-11-11python GUI庫圖形界面開發(fā)之PyQt5時間控件QTimer詳細(xì)使用方法與實(shí)例
這篇文章主要介紹了python GUI庫圖形界面開發(fā)之PyQt5時間控件QTimer詳細(xì)使用方法與實(shí)例,需要的朋友可以參考下2020-02-02Python3實(shí)現(xiàn)計(jì)算兩個數(shù)組的交集算法示例
這篇文章主要介紹了Python3實(shí)現(xiàn)計(jì)算兩個數(shù)組的交集算法,結(jié)合2個實(shí)例形式總結(jié)分析了Python3針對數(shù)組的遍歷、位運(yùn)算以及元素的添加、刪除等相關(guān)操作技巧,需要的朋友可以參考下2019-04-04Python中元組的基礎(chǔ)介紹及常用操作總結(jié)
元組是一種不可變序列。元組變量的賦值要在定義時就進(jìn)行,這就像C語言中的const變量或是C++的引用,定義時賦值之后就不允許有修改。元組存在的意義是:元組在映射中可以作為鍵使用,因?yàn)橐WC鍵的不變性。元組作為很多內(nèi)置函數(shù)和方法的返回值存在2021-09-09pytorch學(xué)習(xí)教程之自定義數(shù)據(jù)集
這篇文章主要給大家介紹了關(guān)于pytorch學(xué)習(xí)教程之自定義數(shù)據(jù)集的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11Python實(shí)現(xiàn)的根據(jù)IP地址計(jì)算子網(wǎng)掩碼位數(shù)功能示例
這篇文章主要介紹了Python實(shí)現(xiàn)的根據(jù)IP地址計(jì)算子網(wǎng)掩碼位數(shù)功能,涉及Python數(shù)值運(yùn)算相關(guān)操作技巧,需要的朋友可以參考下2018-05-05Python的Flask框架中@app.route的用法教程
這篇文章主要介紹了Python的Flask框架中@app.route的用法教程,包括相關(guān)的正則表達(dá)式講解,是Flask學(xué)習(xí)過程當(dāng)中的基礎(chǔ)知識,需要的朋友可以參考下2015-03-03在Heroku云平臺上部署Python的Django框架的教程
這篇文章主要介紹了在Heroku云平臺上部署Python的Django框架的教程,Heroku云平臺使用了Git版本控制系統(tǒng),所以本教程主要提供了配置所需要的Git腳本,需要的朋友可以參考下2015-04-04