PyTorch中torch.no_grad()用法舉例詳解
前言
torch.no_grad() 是 PyTorch 中的一個上下文管理器,用于在上下文中臨時禁用自動梯度計算。它在模型評估或推理階段非常有用,因為在這些階段,我們通常不需要計算梯度。禁用梯度計算可以減少內(nèi)存消耗,并加快計算速度。
基本概念
在 PyTorch 中,每次對 requires_grad=True 的張量進(jìn)行操作時,PyTorch 會構(gòu)建一個計算圖(computation graph),用于計算反向傳播的梯度。這對訓(xùn)練模型是必要的,但在評估或推理時不需要。因此,我們可以使用 torch.no_grad() 來臨時禁用這些計算圖的構(gòu)建和梯度計算。
用法
torch.no_grad() 的使用非常簡單。只需要將不需要梯度計算的代碼塊放在 with torch.no_grad(): 下即可。
示例代碼
以下是一個使用 torch.no_grad() 的示例:
import torch # 創(chuàng)建一個張量,并設(shè)置 requires_grad=True 以便記錄梯度 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) # 在 torch.no_grad() 上下文中禁用梯度計算 with torch.no_grad(): y = x + 2 print(y) # 此時,x 的 requires_grad 屬性仍然為 True,但 y 的 requires_grad 屬性為 False print("x 的 requires_grad:", x.requires_grad) print("y 的 requires_grad:", y.requires_grad)
詳細(xì)解釋
創(chuàng)建張量并設(shè)置 requires_grad=True:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
創(chuàng)建一個包含三個元素的張量 x。
設(shè)置 requires_grad=True,告訴 PyTorch 需要為該張量記錄梯度。
禁用梯度計算:
with torch.no_grad(): y = x + 2 print(y)
進(jìn)入 torch.no_grad() 上下文,臨時禁用梯度計算。
在上下文中,對 x 進(jìn)行加法操作,得到新的張量 y。
打印 y,此時 y 的 requires_grad 屬性為 False。
查看 requires_grad 屬性:
print("x 的 requires_grad:", x.requires_grad) print("y 的 requires_grad:", y.requires_grad)
打印 x 的 requires_grad 屬性,仍然為 True。
打印 y 的 requires_grad 屬性,已被禁用為 False。
使用場景
模型評估
在評估模型性能時,不需要計算梯度。使用 torch.no_grad() 可以提高評估速度和減少內(nèi)存消耗。
model.eval() # 切換到評估模式 with torch.no_grad(): for data in validation_loader: outputs = model(data) # 計算評估指標(biāo)
模型推理
在部署和推理階段,只需要前向傳播,不需要反向傳播,因此可以使用 torch.no_grad()。
with torch.no_grad(): outputs = model(inputs) predicted = torch.argmax(outputs, dim=1)
初始化權(quán)重或其他不需要梯度的操作
在某些初始化或操作中,不需要梯度計算。
with torch.no_grad(): model.weight.fill_(1.0) # 直接修改權(quán)重
小結(jié)
torch.no_grad() 是一個用于禁用梯度計算的上下文管理器,適用于模型評估、推理等不需要梯度計算的場景。使用 torch.no_grad() 可以顯著減少內(nèi)存使用和加速計算。通過理解和合理使用 torch.no_grad(),可以使得模型評估和推理更加高效和穩(wěn)定。
額外注意事項
訓(xùn)練模式與評估模式:
在使用 torch.no_grad() 時,通常還會將模型設(shè)置為評估模式(model.eval()),以確保某些層(如 dropout 和 batch normalization)在推理時的行為與訓(xùn)練時不同。
嵌套使用:
torch.no_grad() 可以嵌套使用,內(nèi)層的 torch.no_grad() 仍然會禁用梯度計算。
with torch.no_grad(): with torch.no_grad(): y = x + 2 print(y)
恢復(fù)梯度計算:
在 torch.no_grad() 上下文管理器退出后,梯度計算會自動恢復(fù),不需要額外操作。
with torch.no_grad(): y = x + 2 print(y) # 這里梯度計算恢復(fù) z = x * 2 print(z.requires_grad) # True
通過合理使用 torch.no_grad(),可以在不需要梯度計算的場景中提升性能并節(jié)省資源。
總結(jié)
到此這篇關(guān)于PyTorch中torch.no_grad()用法舉例詳解的文章就介紹到這了,更多相關(guān)PyTorch torch.no_grad()詳解內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
PyTorch?之?強(qiáng)大的?hub?模塊和搭建神經(jīng)網(wǎng)絡(luò)進(jìn)行氣溫預(yù)測
hub 模塊是調(diào)用別人訓(xùn)練好的網(wǎng)絡(luò)架構(gòu)以及訓(xùn)練好的權(quán)重參數(shù),使得自己的一行代碼就可以解決問題,方便大家進(jìn)行調(diào)用,這篇文章主要介紹了PyTorch?之?強(qiáng)大的?hub?模塊和搭建神經(jīng)網(wǎng)絡(luò)進(jìn)行氣溫預(yù)測,需要的朋友可以參考下2023-03-03Python圖像處理之目標(biāo)物體輪廓提取的實現(xiàn)方法
目標(biāo)物體的輪廓實質(zhì)是指一系列像素點構(gòu)成,這些點構(gòu)成了一個有序的點集,這篇文章主要給大家介紹了關(guān)于Python圖像處理之目標(biāo)物體輪廓提取的實現(xiàn)方法,需要的朋友可以參考下2021-08-08Python編寫Windows Service服務(wù)程序
這篇文章主要為大家詳細(xì)介紹了Python編寫Windows Service服務(wù)程序,具有一定的參考價值,感興趣的小伙伴們可以參考一下2018-01-01