亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch如何自定義forward和backward函數(shù)

 更新時(shí)間:2024年10月12日 16:08:13   作者:xx_xjm  
PyTorch自動(dòng)求導(dǎo)功能強(qiáng)大,但在特定情況下需要用戶自行定義backward函數(shù),通過實(shí)例解釋了保存變量、計(jì)算梯度、鏈?zhǔn)椒▌t等核心概念,并展示了如何通過自定義函數(shù)集成到網(wǎng)絡(luò)中以及如何正確返回梯度,此外,還討論了多輸出情況下的梯度傳遞

pytorch自定義forward和backward函數(shù)

pytorch會(huì)自動(dòng)求導(dǎo),但是當(dāng)遇到無法自動(dòng)求導(dǎo)的時(shí)候,需要自己認(rèn)為定義求導(dǎo)過程,這個(gè)時(shí)候就涉及到要定義自己的forward和backward函數(shù)。

舉例如下:

看到這里,大家應(yīng)該會(huì)有很多疑問

比如:

  • 1:ctx.save_for_backward和ctx.saved_tensors的含義
  • 2:backward中各個(gè)計(jì)算函數(shù)的意義,以及backward的輸入?yún)?shù)grad_out是什么,以及grad_out包含哪些數(shù)據(jù)。

針對(duì)以上問題,我們一個(gè)個(gè)解答

  • 第一個(gè)問題:百度吧,答案很多?。。?!
  • 第二個(gè)問題:拿上面這個(gè)例子來看,我們定義了一個(gè)類似于線性層的東西,但注意這不是線性層,因?yàn)槲覀兪侵苯影演斎牒蛍eight用*來做點(diǎn)對(duì)點(diǎn)的乘法的,所以這不是我們通常情況下的線性層。

但是這么看也費(fèi)勁,我們寫一個(gè)網(wǎng)絡(luò),把這個(gè)函數(shù)加到網(wǎng)絡(luò)中去,再完整的跑一遍看吧!

測(cè)試代碼

結(jié)果如下:

來進(jìn)行解答

首先,backward函數(shù)的返回值,就是對(duì)應(yīng)著forward里面的參數(shù)的梯度,也就是說,forward函數(shù)里面有幾個(gè)輸入?yún)?shù),那么backward函數(shù)的輸出就要有幾個(gè)!為什么是這樣?

我們首先要理解backward的輸入grad_out,為什么backward的參數(shù)就是一個(gè),因?yàn)檫@是根據(jù)鏈?zhǔn)椒▌t來的

比如,我們定義三個(gè)函數(shù)H(對(duì)應(yīng)上面網(wǎng)絡(luò)中l(wèi)inear1),F(自定義函數(shù)xjm_inter),D(對(duì)應(yīng)上面網(wǎng)絡(luò)中l(wèi)inear2),定義一個(gè)輸入x(對(duì)應(yīng)上面輸入a),定義一個(gè)輸出y(對(duì)應(yīng)上面輸出b):

y = D(F(H(X)))

現(xiàn)在,我們求y對(duì)x的偏導(dǎo),那么:

dy/dx = dy/dD * dD/dF * dF/dH * dH/dx

好吧看到這里你可能還是不懂,為什么backward的參數(shù)就是一個(gè)grad_out?。?/p>

我們韓式以上面則個(gè)函數(shù)為例子,但是,我們現(xiàn)在不求y對(duì)x的導(dǎo)數(shù),我們假設(shè)F函數(shù)有一個(gè)葉子節(jié)點(diǎn)(或者說requires_grad=True)的參數(shù)w1,現(xiàn)在我們要求y對(duì)w1的導(dǎo)數(shù):

所以

dy/dw1 = dy/dD *dD/dF * dF/dw1

那么此時(shí),F(xiàn)就是我們上面代碼中自定義的xjm_inter函數(shù),則 grad_out = dy/dD *dD/dF。

怎么理解呢,根據(jù)鏈?zhǔn)椒▌t,我們呢所定義的網(wǎng)絡(luò)中的每一層都是一個(gè)單獨(dú)的函數(shù),所以函數(shù)中的變量的最終求導(dǎo)其實(shí)只取決于該函數(shù)本身,鏈?zhǔn)椒▌t求導(dǎo)傳遞過來的其實(shí)永遠(yuǎn)都知識(shí)一個(gè)值,這就是為什么backward函數(shù)的輸出只有一個(gè)。

擴(kuò)展

當(dāng)forward的輸出有多個(gè)的時(shí)候,那么就有多個(gè)鏈?zhǔn)椒▌t,因?yàn)榭梢酝瑫r(shí)對(duì)x或者對(duì)w求導(dǎo),此時(shí)backward的輸入可以是一個(gè),也可以是對(duì)應(yīng)forward輸出的個(gè)數(shù),如果是一個(gè)則是一個(gè)元組,包含對(duì)應(yīng)的梯度?。?!

那么我們的backward要實(shí)現(xiàn)什么樣的功能呢?說到這里,大家應(yīng)該大概能明白了,就是實(shí)現(xiàn)當(dāng)前層那的梯度計(jì)算,并進(jìn)行返回,所以,這也是為什么backward的返回值要和forward的輸入值一一對(duì)應(yīng),否則會(huì)報(bào)錯(cuò)。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評(píng)論