Pytorch反向傳播中的細(xì)節(jié)-計(jì)算梯度時(shí)的默認(rèn)累加操作
Pytorch反向傳播計(jì)算梯度默認(rèn)累加
今天學(xué)習(xí)pytorch實(shí)現(xiàn)簡(jiǎn)單的線性回歸,發(fā)現(xiàn)了pytorch的反向傳播時(shí)計(jì)算梯度采用的累加機(jī)制, 于是百度來(lái)一下,好多博客都說(shuō)了累加機(jī)制,但是好多都沒(méi)有說(shuō)明這個(gè)累加機(jī)制到底會(huì)有啥影響, 所以我趁著自己練習(xí)的一個(gè)例子正好直觀的看一下以及如何解決:
pytorch實(shí)現(xiàn)線性回歸
先附上試驗(yàn)代碼來(lái)感受一下:
torch.manual_seed(6) lr = 0.01 # 學(xué)習(xí)率 result = [] # 創(chuàng)建訓(xùn)練數(shù)據(jù) x = torch.rand(20, 1) * 10 y = 2 * x + (5 + torch.randn(20, 1)) # 構(gòu)建線性回歸函數(shù) w = torch.randn((1), requires_grad=True) b = torch.zeros((1), requires_grad=True) # 這里是迭代過(guò)程,為了看pytorch的反向傳播計(jì)算梯度的細(xì)節(jié),我先迭代兩次 for iteration in range(2): # 前向傳播 wx = torch.mul(w, x) y_pred = torch.add(wx, b) # 計(jì)算 MSE loss loss = (0.5 * (y - y_pred) ** 2).mean() # 反向傳播 loss.backward() # 這里看一下反向傳播計(jì)算的梯度 print("w.grad:", w.grad) print("b.grad:", b.grad) # 更新參數(shù) b.data.sub_(lr * b.grad) w.data.sub_(lr * w.grad)
上面的代碼比較簡(jiǎn)單,迭代了兩次, 看一下計(jì)算的梯度結(jié)果:
w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])
然后我稍微加兩行代碼, 就是在反向傳播上面,我手動(dòng)添加梯度清零操作的代碼,再感受一下結(jié)果:
torch.manual_seed(6) lr = 0.01 result = [] # 創(chuàng)建訓(xùn)練數(shù)據(jù) x = torch.rand(20, 1) * 10 #print(x) y = 2 * x + (5 + torch.randn(20, 1)) #print(y) # 構(gòu)建線性回歸函數(shù) w = torch.randn((1), requires_grad=True) #print(w) b = torch.zeros((1), requires_grad=True) #print(b) for iteration in range(2): # 前向傳播 wx = torch.mul(w, x) y_pred = torch.add(wx, b) # 計(jì)算 MSE loss loss = (0.5 * (y - y_pred) ** 2).mean() # 由于pytorch反向傳播中,梯度是累加的,所以如果不想先前的梯度影響當(dāng)前梯度的計(jì)算,需要手動(dòng)清0 if iteration > 0: w.grad.data.zero_() b.grad.data.zero_() # 反向傳播 loss.backward() # 看一下梯度 print("w.grad:", w.grad) print("b.grad:", b.grad) # 更新參數(shù) b.data.sub_(lr * b.grad) w.data.sub_(lr * w.grad)
w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])
從上面可以發(fā)現(xiàn),pytorch在反向傳播的時(shí)候,確實(shí)是默認(rèn)累加上了上一次求的梯度, 如果不想讓上一次的梯度影響自己本次梯度計(jì)算的話,需要手動(dòng)的清零。
但是, 如果不進(jìn)行手動(dòng)清零的話,會(huì)有什么后果呢? 我在這次線性回歸試驗(yàn)中,遇到的后果就是loss值反復(fù)的震蕩不收斂。下面感受一下:
torch.manual_seed(6) lr = 0.01 result = [] # 創(chuàng)建訓(xùn)練數(shù)據(jù) x = torch.rand(20, 1) * 10 #print(x) y = 2 * x + (5 + torch.randn(20, 1)) #print(y) # 構(gòu)建線性回歸函數(shù) w = torch.randn((1), requires_grad=True) #print(w) b = torch.zeros((1), requires_grad=True) #print(b) for iteration in range(1000): # 前向傳播 wx = torch.mul(w, x) y_pred = torch.add(wx, b) # 計(jì)算 MSE loss loss = (0.5 * (y - y_pred) ** 2).mean() # print("iteration {}: loss {}".format(iteration, loss)) result.append(loss) # 由于pytorch反向傳播中,梯度是累加的,所以如果不想先前的梯度影響當(dāng)前梯度的計(jì)算,需要手動(dòng)清0 #if iteration > 0: # w.grad.data.zero_() # b.grad.data.zero_() # 反向傳播 loss.backward() # 更新參數(shù) b.data.sub_(lr * b.grad) w.data.sub_(lr * w.grad) if loss.data.numpy() < 1: break plt.plot(result)
上面的代碼中,我沒(méi)有進(jìn)行手動(dòng)清零,迭代1000次, 把每一次的loss放到來(lái)result中, 然后畫出圖像,感受一下結(jié)果:
接下來(lái),我把手動(dòng)清零的注釋打開(kāi),進(jìn)行每次迭代之后的手動(dòng)清零操作,得到的結(jié)果:
可以看到,這個(gè)才是理想中的反向傳播求導(dǎo),然后更新參數(shù)后得到的loss值的變化。
總結(jié)
這次主要是記錄一下,pytorch在進(jìn)行反向傳播計(jì)算梯度的時(shí)候的累加機(jī)制到底是什么樣子? 至于為什么采用這種機(jī)制,我也搜了一下,大部分給出的結(jié)果是這樣子的:
但是如果不想累加的話,可以采用手動(dòng)清零的方式,只需要在每次迭代時(shí)加上即可
w.grad.data.zero_() b.grad.data.zero_()
另外, 在搜索資料的時(shí)候,在一篇博客上看到兩個(gè)不錯(cuò)的線性回歸時(shí)pytorch的計(jì)算圖在這里借用一下:
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python pandas 對(duì)series和dataframe的重置索引reindex方法
今天小編就為大家分享一篇python pandas 對(duì)series和dataframe的重置索引reindex方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-06-06python掃描proxy并獲取可用代理ip的實(shí)例
下面小編就為大家?guī)?lái)一篇python掃描proxy并獲取可用代理ip的實(shí)例。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-08-08手把手教你怎么用Python實(shí)現(xiàn)zip文件密碼的破解
之前在家里的老電腦中,發(fā)現(xiàn)一個(gè)加密zip壓縮包,由于時(shí)隔太久忘記密碼了,依稀記得密碼是6位字母加數(shù)字,網(wǎng)上下載了很多破解密碼的軟件都沒(méi)有效果,于是想到自己用Python寫一個(gè)暴力破解密碼的腳本,需要的朋友可以參考下2021-05-05詳解使用python的logging模塊在stdout輸出的兩種方法
這篇文章主要介紹了詳解使用python的logging模塊在stdout輸出的相關(guān)資料,需要的朋友可以參考下2017-05-05Python循環(huán)實(shí)現(xiàn)n的全排列功能
這篇文章主要介紹了Python循環(huán)實(shí)現(xiàn)n的全排列功能,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-09-09Python實(shí)現(xiàn)softmax反向傳播的示例代碼
這篇文章主要為大家詳細(xì)介紹了Python實(shí)現(xiàn)softmax反向傳播的相關(guān)資料,文中的示例代碼講解詳細(xì),具有一定的參考價(jià)值,感興趣的可以了解一下2023-04-04