在pytorch中實(shí)現(xiàn)只讓指定變量向后傳播梯度
pytorch中如何只讓指定變量向后傳播梯度?
(或者說(shuō)如何讓指定變量不參與后向傳播?)
Consider the following loss (reconstructed here from the verification code below), and suppose we want the derivative of L with respect to xvar:

L = (out1 - out2)^2, where out1 = xvar * y1var and out2 = xvar * y2var

In (1), the derivative of L with respect to xvar includes both the out1 branch and the out2 branch: dL/dxvar = 2(out1 - out2)(y1var - y2var);
In (2), the derivative of L with respect to xvar includes only the out2 branch, because out1 has requires_grad=False: dL/dxvar = -2(out1 - out2) * y2var;
In (3), the derivative of L with respect to xvar includes only the out1 branch, because out2 has requires_grad=False: dL/dxvar = 2(out1 - out2) * y1var.
This is verified as follows:
```python
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 10:02:04 2018
@author: hy
"""
import torch
from torch.autograd import Variable

print("Pytorch version: {}".format(torch.__version__))

x = torch.Tensor([1])
xvar = Variable(x, requires_grad=True)
y1 = torch.Tensor([2])
y2 = torch.Tensor([7])
y1var = Variable(y1)
y2var = Variable(y2)

# (1) both out1 and out2 take part in backpropagation
print("For (1)")
print("xvar requires_grad: {}".format(xvar.requires_grad))
print("y1var requires_grad: {}".format(y1var.requires_grad))
print("y2var requires_grad: {}".format(y2var.requires_grad))
out1 = xvar * y1var
print("out1 requires_grad: {}".format(out1.requires_grad))
out2 = xvar * y2var
print("out2 requires_grad: {}".format(out2.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

# (2) out1 is detached, so only the out2 branch contributes to the gradient
print("For (2)")
out1 = xvar * y1var
out2 = xvar * y2var
out1 = out1.detach()
print("after out1.detach(), out1 requires_grad: {}".format(out1.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()

# (3) out2 is detached, so only the out1 branch contributes to the gradient
print("For (3)")
out1 = xvar * y1var
out2 = xvar * y2var
out2 = out2.detach()
print("after out2.detach(), out2 requires_grad: {}".format(out2.requires_grad))
L = torch.pow(out1 - out2, 2)
L.backward()
print("xvar.grad: {}".format(xvar.grad))
xvar.grad.data.zero_()
```
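The script above uses the old `Variable` API from PyTorch 0.3. As a sketch of the same three cases on PyTorch 0.4 or later, where plain tensors carry `requires_grad` directly (the helper function name here is illustrative):

```python
import torch

def grad_for_case(detach_out1=False, detach_out2=False):
    # Leaf tensor we differentiate with respect to.
    x = torch.tensor([1.0], requires_grad=True)
    y1 = torch.tensor([2.0])
    y2 = torch.tensor([7.0])

    out1 = x * y1
    out2 = x * y2
    # detach() cuts the chosen branch out of the autograd graph.
    if detach_out1:
        out1 = out1.detach()
    if detach_out2:
        out2 = out2.detach()

    L = (out1 - out2) ** 2
    L.backward()
    return x.grad.item()

print(grad_for_case())                  # (1) both branches:  2*(out1-out2)*(y1-y2) =  50.0
print(grad_for_case(detach_out1=True))  # (2) only out2:     -2*(out1-out2)*y2      =  70.0
print(grad_for_case(detach_out2=True))  # (3) only out1:      2*(out1-out2)*y1      = -20.0
```

With x=1, y1=2, y2=7 we have out1=2 and out2=7, so the three analytic gradients are 50, 70, and -20, matching what the Variable-based script prints.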
In PyTorch, setting a variable's requires_grad to False keeps it out of gradient backpropagation.
However, you cannot simply assign out1.requires_grad = False, because requires_grad can only be changed on leaf variables, and out1 is an intermediate result.
Instead, the Variable type provides a detach() method; the variable it returns has requires_grad set to False.
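A minimal sketch of why the direct assignment fails: autograd rejects writes to `requires_grad` on non-leaf tensors with a `RuntimeError`, while `detach()` returns a gradient-free view of the same values (variable names here are illustrative):

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
out1 = x * 2          # intermediate (non-leaf) result of a tracked op

# Direct assignment on a non-leaf tensor is rejected by autograd.
try:
    out1.requires_grad = False
except RuntimeError as e:
    print("RuntimeError:", e)

# detach() is the supported way: same values, no gradient tracking.
out1_d = out1.detach()
print(out1_d.requires_grad)   # False
print(out1.requires_grad)     # True -- the original tensor is untouched
```

Note that `detach()` does not modify `out1` in place, which is why the article's script reassigns `out1 = out1.detach()`.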
Note: if out1 and out2 both have requires_grad=False, then L.backward() raises an error and xvar.grad is never populated, because no gradient can flow back to xvar.
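That failure mode can be demonstrated directly: when every input to L is detached, L itself has requires_grad=False and no grad_fn, so backward() raises a RuntimeError (a small sketch):

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
out1 = (x * 2).detach()
out2 = (x * 7).detach()

L = (out1 - out2) ** 2
print(L.requires_grad)       # False -- nothing upstream tracks gradients

try:
    L.backward()
except RuntimeError as e:
    print("RuntimeError:", e)

print(x.grad)                # None -- no gradient ever reached x
```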
Addendum:
volatile=True marks a variable as not needing gradient computation at all. From the PyTorch documentation of that era: "Volatile is recommended for purely inference mode, when you're sure you won't be even calling .backward(). It's more efficient than any other autograd setting - it will use the absolute minimal amount of memory to evaluate the model. volatile also determines that requires_grad is False." (Note: volatile was removed in PyTorch 0.4; the torch.no_grad() context manager replaces it.)
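Since volatile is gone in modern PyTorch, the equivalent for inference is the torch.no_grad() context manager: every result computed inside it has requires_grad=False, sketched below:

```python
import torch

x = torch.tensor([1.0], requires_grad=True)

# Inside no_grad(), autograd records nothing: outputs do not require grad.
with torch.no_grad():
    out_infer = x * 2
print(out_infer.requires_grad)   # False

# Outside the context, gradient tracking resumes as usual.
out_train = x * 2
print(out_train.requires_grad)   # True
```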
That concludes this article on restricting backpropagation to specified variables in PyTorch. We hope it serves as a useful reference, and thank you for supporting 腳本之家.