解決pytorch?model代碼內(nèi)tensor?device不一致的問題
pytorch model代碼內(nèi)tensor device不一致的問題
在編寫一段處理兩個tensor的代碼如下,需要在forward函數(shù)內(nèi)編寫函數(shù)創(chuàng)建一個新的tensor進(jìn)行索引的掩碼計算
# todo(liang)空間交換 def compute_sim_and_swap(t1, t2, threshold=0.7): n, c, h, w = t1.shape sim = torch.nn.functional.cosine_similarity(t1, t2, dim=1) # n, h, w sim = sim.unsqueeze(0) # c, n, h, w expand_tensor = sim.clone() # 使用拼接構(gòu)建相同的維度 for _ in range(c-1): # c, n, h, w sim = torch.cat([sim, expand_tensor], dim=0) sim = sim.permute(1, 0, 2, 3) # n, c, h, w # 創(chuàng)建邏輯掩碼,小于 threshold 的將掩碼變?yōu)?True 用于交換 mask = sim < threshold indices = torch.rand(mask.shape) < 0.5 t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices] return t1, t2
這段代碼報了這個錯誤
File "xxx/network.py", line 347, in compute_sim_and_swap
t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
統(tǒng)一下進(jìn)行掩碼計算的張量的設(shè)備即可
device = mask.Device indices = indices.to(device)
PyTorch 多GPU使用torch.nn.DataParallel訓(xùn)練參數(shù)不一致問題
在多GPU訓(xùn)練時,遇到了下述的錯誤:
1. Expected tensor for argument 1 'input' to have the same device as tensor for argument 2 'weight'; but device 0 does not equal 1
2. RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
造成這個錯誤的可能性有挺多,總起來是模型、輸入、模型內(nèi)參數(shù)不在一個GPU上。本人是在調(diào)試RandLA-Net pytorch源碼,希望使用雙GPU訓(xùn)練,經(jīng)過嘗試解決這個問題,此處做一個記錄,希望給后來人一個提醒。經(jīng)過調(diào)試,發(fā)現(xiàn)報錯的地方主要是在數(shù)據(jù)拼接的時候,即一個數(shù)據(jù)在GPU0上,一個數(shù)據(jù)在GPU1上,這就會出現(xiàn)錯誤,相關(guān)代碼如下:
return torch.cat(( self.mlp(concat), features.expand(B, -1, N, K) ), dim=-3)
上述代碼中,必須保證self.mlp(concat)與features.expand(B, -1, N, K)在同一個GPU中。在多GPU運算時,features(此時是輸入變量)有可能放在任何一個GPU中,因此此處在拼接前,獲取一下features的GPU,然后將concat放入相應(yīng)的GPU中,再進(jìn)行數(shù)據(jù)拼接就可以了,代碼如下:
device = features.device concat = concat.to(device) return torch.cat(( self.mlp(concat), features.expand(B, -1, N, K) ), dim=-3)
該源碼中默認(rèn)狀態(tài)下device是一個固定的值,在多GPU訓(xùn)練狀態(tài)下就會報錯,代碼中還有幾處數(shù)據(jù)融合,大家可以依據(jù)上述思路做修改。此外該源碼中由于把device的值寫死了,訓(xùn)練好的模型也必須在相應(yīng)的GPU中做推理,如在cuda0中訓(xùn)練的模型如果在cuda1中推理就會報錯,各位可以依據(jù)此思路對源碼做相應(yīng)的修改。如果修改有困難,可以私信我,我可以把相關(guān)修改后的源碼分享。
到此這篇關(guān)于pytorch model代碼內(nèi)tensor device不一致的問題的文章就介紹到這了,更多相關(guān)pytorch tensor device不一致內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python實現(xiàn)不同數(shù)據(jù)庫間數(shù)據(jù)同步功能
這篇文章主要介紹了python實現(xiàn)不同數(shù)據(jù)庫間數(shù)據(jù)同步功能,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-02-02Python?web實戰(zhàn)教程之Django文件上傳和處理詳解
Django和Flask都是Python的Web框架,用于開發(fā)Web應(yīng)用程序,這篇文章主要給大家介紹了關(guān)于Python?web實戰(zhàn)教程之Django文件上傳和處理的相關(guān)資料,文中通過代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-12-12python3.6.4安裝opencv3.4.2的實現(xiàn)
這篇文章主要介紹了python3.6.4安裝opencv3.4.2的實現(xiàn)方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-10-10詳解Python編程中基本的數(shù)學(xué)計算使用
這篇文章主要介紹了Python編程中基本的數(shù)學(xué)計算使用,其中重點講了除法運算及相關(guān)division模塊的使用,需要的朋友可以參考下2016-02-02