pytorch訓(xùn)練時(shí)的顯存占用遞增的問題解決
遇到的問題:
在pytorch訓(xùn)練過程中突然out of memory。
解決方法:
1. 測(cè)試的時(shí)候爆顯存有可能是忘記設(shè)置no_grad
加入 with torch.no_grad()
model.eval() with torch.no_grad(): ? ? ? ? for idx, (data, target) in enumerate(data_loader): ? ? ? ? ? ? if args.gpu != -1: ? ? ? ? ? ? ? ? data, target = data.to(args.device), target.to(args.device) ? ? ? ? ? ? log_probs = net_g(data) ? ? ? ? ? ? probs.append(log_probs) ? ? ? ? ? ?? ? ? ? ? ? ? # sum up batch loss ? ? ? ? ? ? test_loss += F.cross_entropy(log_probs, target, reduction='sum').item() ? ? ? ? ? ? # get the index of the max log-probability ? ? ? ? ? ? y_pred = log_probs.data.max(1, keepdim=True)[1] ? ? ? ? ? ? correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
2. loss.item()
寫成loss_train = loss_train + loss.item(),不能直接寫loss_train = loss_train + loss
3. 在代碼中添加以下兩行:
torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True
4. del操作后再加上torch.cuda.empty_cache()
單獨(dú)使用del、torch.cuda.empty_cache()效果都不明顯,因?yàn)閑mpty_cache()不會(huì)釋放還被占用的內(nèi)存。
所以這里使用了del讓對(duì)應(yīng)數(shù)據(jù)成為“沒標(biāo)簽”的垃圾,之后這些垃圾所占的空間就會(huì)被empty_cache()回收。
"""添加了最后兩行,img和segm是圖像和標(biāo)簽輸入,很明顯通過.cuda()已經(jīng)是被存在在顯存里了; ? ?outputs是模型的輸出,模型在顯存里當(dāng)然其輸出也在顯存里;loss是通過在顯存里的segm和 ? ?outputs算出來的,其也在顯存里。這4個(gè)對(duì)象都是一次性的,使用后應(yīng)及時(shí)把其從顯存中清除 ? ?(當(dāng)然如果你顯存夠大也可以忽略)。""" ? def train(model, data_loader, batch_size, optimizer): ? ? model.train() ? ? total_loss = 0 ? ? accumulated_steps = 32 // batch_size ? ? optimizer.zero_grad() ? ? for idx, (img, segm) in enumerate(tqdm(data_loader)): ? ? ? ? img = img.cuda() ? ? ? ? segm = segm.cuda() ? ? ? ? outputs = model(img) ? ? ? ? loss = criterion(outputs, segm) ? ? ? ? (loss/accumulated_steps).backward() ? ? ? ? if (idx + 1 ) % accumulated_steps == 0: ? ? ? ? ? ? optimizer.step()? ? ? ? ? ? ? optimizer.zero_grad() ? ? ? ? total_loss += loss.item() ? ? ? ?? ? ? ? ? # delete caches ? ? ? ? del img, segm, outputs, loss ? ? ? ? torch.cuda.empty_cache()
補(bǔ)充:Pytorch顯存不斷增長(zhǎng)問題的解決思路
思路很簡(jiǎn)單,就是在代碼的運(yùn)行階段輸出顯存占用量,觀察在哪一塊存在顯存劇烈增加或者顯存異常變化的情況。
但是在這個(gè)過程中要分級(jí)確認(rèn)問題點(diǎn),也即如果存在三個(gè)文件main.py、train.py、model.py。
在此種思路下,應(yīng)該先在main.py中確定問題點(diǎn),然后,從main.py中進(jìn)入到train.py中,再次輸出顯存占用量,確定問題點(diǎn)在哪。
隨后,再從train.py中的問題點(diǎn),進(jìn)入到model.py中,再次確認(rèn)。
如果還有更深層次的調(diào)用,可以繼續(xù)追溯下去。
例如:
main.py
def train(model,epochs,data): for e in range(epochs): print("1:{}".format(torch.cuda.memory_allocated(0))) train_epoch(model,data) print("2:{}".format(torch.cuda.memory_allocated(0))) eval(model,data) print("3:{}".format(torch.cuda.memory_allocated(0)))
若1與2之間顯存增加極為劇烈,說明問題出在train_epoch中,進(jìn)一步進(jìn)入到train.py中。
train.py
def train_epoch(model,data): model.train() optim=torch.optimizer() for batch_data in data: print("1:{}".format(torch.cuda.memory_allocated(0))) output=model(batch_data) print("2:{}".format(torch.cuda.memory_allocated(0))) loss=loss(output,data.target) print("3:{}".format(torch.cuda.memory_allocated(0))) optim.zero_grad() print("4:{}".format(torch.cuda.memory_allocated(0))) loss.backward() print("5:{}".format(torch.cuda.memory_allocated(0))) utils.func(model) print("6:{}".format(torch.cuda.memory_allocated(0)))
如果在1,2之間,5,6之間同時(shí)出現(xiàn)顯存增加異常的情況。此時(shí)需要使用控制變量法,例如我們先讓5,6之間的代碼失效,然后運(yùn)行,觀察是否仍然存在顯存爆炸。如果沒有,說明問題就出在5,6之間下一級(jí)的代碼中。進(jìn)入到下一級(jí)代碼,進(jìn)行調(diào)試:
utils.py
def func(model): print("1:{}".format(torch.cuda.memory_allocated(0))) a=f1(model) print("2:{}".format(torch.cuda.memory_allocated(0))) b=f2(a) print("3:{}".format(torch.cuda.memory_allocated(0))) c=f3(b) print("4:{}".format(torch.cuda.memory_allocated(0))) d=f4(c) print("5:{}".format(torch.cuda.memory_allocated(0)))
此時(shí)我們?cè)僬故玖硪环N調(diào)試思路,先注釋第5行之后的代碼,觀察顯存是否存在先訓(xùn)爆炸,如果沒有,則注釋掉第7行之后的,直至確定哪一行的代碼出現(xiàn)導(dǎo)致了顯存爆炸。假設(shè)第9行起作用后,代碼出現(xiàn)顯存爆炸,說明問題出在第九行,顯存爆炸的問題鎖定。
參考鏈接:
http://www.zzvips.com/article/196059.html
https://blog.csdn.net/fish_like_apple/article/details/101448551
到此這篇關(guān)于pytorch訓(xùn)練時(shí)的顯存占用遞增的問題解決的文章就介紹到這了,更多相關(guān)pytorch 顯存占用遞增內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
詳解使用python crontab設(shè)置linux定時(shí)任務(wù)
本篇文章主要介紹了使用python crontab設(shè)置linux定時(shí)任務(wù),具有一定的參考價(jià)值,有需要的可以了解一下。2016-12-12numpy 中l(wèi)inspace函數(shù)的使用
本文主要介紹了numpy 中l(wèi)inspace函數(shù)的使用,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03Pandas對(duì)數(shù)值進(jìn)行分箱操作的4種方法總結(jié)
分箱是一種常見的數(shù)據(jù)預(yù)處理技術(shù)有時(shí)也被稱為分桶或離散化,他可用于將連續(xù)數(shù)據(jù)的間隔分組到“箱”或“桶”中。本文將使用python?Pandas庫對(duì)數(shù)值進(jìn)行分箱的4種方法,感興趣的可以了解一下2022-05-05使用python實(shí)現(xiàn)數(shù)據(jù)篩查
一般數(shù)據(jù)篩查可以通過Python中的pandas庫來實(shí)現(xiàn),下面小編就來為大家介紹一下Python如何利用pandas實(shí)現(xiàn)數(shù)據(jù)篩查,感興趣的小伙伴可以一起學(xué)習(xí)一下2023-10-10結(jié)合Python網(wǎng)絡(luò)爬蟲做一個(gè)今日新聞小程序
本篇文章介紹了我在開發(fā)過程中遇到的一個(gè)問題,以及解決該問題的過程及思路,通讀本篇對(duì)大家的學(xué)習(xí)或工作具有一定的價(jià)值,需要的朋友可以參考下2021-09-09Python連接mysql數(shù)據(jù)庫及簡(jiǎn)單增刪改查操作示例代碼
這篇文章主要介紹了Python連接mysql數(shù)據(jù)庫及簡(jiǎn)單增刪改查操作示例代碼,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08