Pytorch 中retain_graph的用法詳解
用法分析
在查看SRGAN源碼時(shí)有如下?lián)p失函數(shù),其中設(shè)置了retain_graph=True,其作用是什么?
############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
real_img = Variable(target)
if torch.cuda.is_available():
real_img = real_img.cuda()
z = Variable(data)
if torch.cuda.is_available():
z = z.cuda()
fake_img = netG(z)
netD.zero_grad()
real_out = netD(real_img).mean()
fake_out = netD(fake_img).mean()
d_loss = 1 - real_out + fake_out
d_loss.backward(retain_graph=True) #####
optimizerD.step()
############################
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
###########################
netG.zero_grad()
g_loss = generator_criterion(fake_out, fake_img, real_img)
g_loss.backward()
optimizerG.step()
fake_img = netG(z)
fake_out = netD(fake_img).mean()
g_loss = generator_criterion(fake_out, fake_img, real_img)
running_results['g_loss'] += g_loss.data[0] * batch_size
d_loss = 1 - real_out + fake_out
running_results['d_loss'] += d_loss.data[0] * batch_size
running_results['d_score'] += real_out.data[0] * batch_size
running_results['g_score'] += fake_out.data[0] * batch_size
在更新D網(wǎng)絡(luò)時(shí)的loss反向傳播過程中使用了retain_graph=True,目的為是為保留該過程中計(jì)算的梯度,后續(xù)G網(wǎng)絡(luò)更新時(shí)使用;
其實(shí)retain_graph這個(gè)參數(shù)在平常中我們是用不到的,但是在特殊的情況下我們會(huì)用到它,
如下代碼:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward() output2.backward()
輸出如下錯(cuò)誤信息:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-19-8ad6b0658906> in <module>() ----> 1 output1.backward() 2 output2.backward() D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph) 91 products. Defaults to ``False``. 92 """ ---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph) 94 95 def register_hook(self, hook): D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables) 88 Variable._execution_engine.run_backward( 89 tensors, grad_tensors, retain_graph, create_graph, ---> 90 allow_unreachable=True) # allow_unreachable flag 91 92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
修改成如下正確:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward(retain_graph=True) output2.backward()
# 假如你有兩個(gè)Loss,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward loss1.backward(retain_graph=True) loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會(huì)被釋放,以便下一次的循環(huán) optimizer.step() # 更新參數(shù)
Variable 類源代碼
class Variable(_C._VariableBase):
"""
Attributes:
data: 任意類型的封裝好的張量。
grad: 保存與data類型和位置相匹配的梯度,此屬性難以分配并且不能重新分配。
requires_grad: 標(biāo)記變量是否已經(jīng)由一個(gè)需要調(diào)用到此變量的子圖創(chuàng)建的bool值。只能在葉子變量上進(jìn)行修改。
volatile: 標(biāo)記變量是否能在推理模式下應(yīng)用(如不保存歷史記錄)的bool值。只能在葉變量上更改。
is_leaf: 標(biāo)記變量是否是圖葉子(如由用戶創(chuàng)建的變量)的bool值.
grad_fn: Gradient function graph trace.
Parameters:
data (any tensor class): 要包裝的張量.
requires_grad (bool): bool型的標(biāo)記值. **Keyword only.**
volatile (bool): bool型的標(biāo)記值. **Keyword only.**
"""
def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
"""計(jì)算關(guān)于當(dāng)前圖葉子變量的梯度,圖使用鏈?zhǔn)椒▌t導(dǎo)致分化
如果Variable是一個(gè)標(biāo)量(例如它包含一個(gè)單元素?cái)?shù)據(jù)),你無需對(duì)backward()指定任何參數(shù)
如果變量不是標(biāo)量(包含多個(gè)元素?cái)?shù)據(jù)的矢量)且需要梯度,函數(shù)需要額外的梯度;
需要指定一個(gè)和tensor的形狀匹配的grad_output參數(shù)(y在指定方向投影對(duì)x的導(dǎo)數(shù));
可以是一個(gè)類型和位置相匹配且包含與自身相關(guān)的不同函數(shù)梯度的張量。
函數(shù)在葉子上累積梯度,調(diào)用前需要對(duì)該葉子進(jìn)行清零。
Arguments:
grad_variables (Tensor, Variable or None):
變量的梯度,如果是一個(gè)張量,除非“create_graph”是True,否則會(huì)自動(dòng)轉(zhuǎn)換成volatile型的變量。
可以為標(biāo)量變量或不需要grad的值指定None值。如果None值可接受,則此參數(shù)可選。
retain_graph (bool, optional): 如果為False,用來計(jì)算梯度的圖將被釋放。
在幾乎所有情況下,將此選項(xiàng)設(shè)置為True不是必需的,通??梢砸愿行У姆绞浇鉀Q。
默認(rèn)值為create_graph的值。
create_graph (bool, optional): 為True時(shí),會(huì)構(gòu)造一個(gè)導(dǎo)數(shù)的圖,用來計(jì)算出更高階導(dǎo)數(shù)結(jié)果。
默認(rèn)為False,除非``gradient``是一個(gè)volatile變量。
"""
torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
def register_hook(self, hook):
"""Registers a backward hook.
每當(dāng)與variable相關(guān)的梯度被計(jì)算時(shí)調(diào)用hook,hook的申明:hook(grad)->Variable or None
不能對(duì)hook的參數(shù)進(jìn)行修改,但可以選擇性地返回一個(gè)新的梯度以用在`grad`的相應(yīng)位置。
函數(shù)返回一個(gè)handle,其``handle.remove()``方法用于將hook從模塊中移除。
Example:
>>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> v.backward(torch.Tensor([1, 1, 1]))
>>> v.grad.data
2
2
2
[torch.FloatTensor of size 3]
>>> h.remove() # removes the hook
"""
if self.volatile:
raise RuntimeError("cannot register a hook on a volatile variable")
if not self.requires_grad:
raise RuntimeError("cannot register a hook on a variable that "
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
if self.grad_fn is not None:
self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
def reinforce(self, reward):
"""Registers a reward obtained as a result of a stochastic process.
區(qū)分隨機(jī)節(jié)點(diǎn)需要為他們提供reward值。如果圖表中包含任何的隨機(jī)操作,都應(yīng)該在其輸出上調(diào)用此函數(shù),否則會(huì)出現(xiàn)錯(cuò)誤。
Parameters:
reward(Tensor): 帶有每個(gè)元素獎(jiǎng)賞的張量,必須與Variable數(shù)據(jù)的設(shè)備位置和形狀相匹配。
"""
if not isinstance(self.grad_fn, StochasticFunction):
raise RuntimeError("reinforce() can be only called on outputs "
"of stochastic functions")
self.grad_fn._reinforce(reward)
def detach(self):
"""返回一個(gè)從當(dāng)前圖分離出來的心變量。
結(jié)果不需要梯度,如果輸入是volatile,則輸出也是volatile。
.. 注意::
返回變量使用與原始變量相同的數(shù)據(jù)張量,并且可以看到其中任何一個(gè)的就地修改,并且可能會(huì)觸發(fā)正確性檢查中的錯(cuò)誤。
"""
result = NoGrad()(self) # this is needed, because it merges version counters
result._grad_fn = None
return result
def detach_(self):
"""從創(chuàng)建它的圖中分離出變量并作為該圖的一個(gè)葉子"""
self._grad_fn = None
self.requires_grad = False
def retain_grad(self):
"""Enables .grad attribute for non-leaf Variables."""
if self.grad_fn is None: # no-op for leaves
return
if not self.requires_grad:
raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
if hasattr(self, 'retains_grad'):
return
weak_self = weakref.ref(self)
def retain_grad_hook(grad):
var = weak_self()
if var is None:
return
if var._grad is None:
var._grad = grad.clone()
else:
var._grad = var._grad + grad
self.register_hook(retain_grad_hook)
self.retains_grad = True
以上這篇Pytorch 中retain_graph的用法詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法示例
這篇文章主要介紹了Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法,涉及Python網(wǎng)頁抓取、正則匹配、文件讀寫及數(shù)值運(yùn)算等相關(guān)操作技巧,需要的朋友可以參考下2018-02-02
好的Python培訓(xùn)機(jī)構(gòu)應(yīng)該具備哪些條件
python是現(xiàn)在開發(fā)的熱潮,大家應(yīng)該如何學(xué)習(xí)呢?許多人選擇自學(xué),還有人會(huì)選擇去培訓(xùn)結(jié)構(gòu)學(xué)習(xí),那么好的培訓(xùn)機(jī)構(gòu)的標(biāo)準(zhǔn)是什么樣的呢?下面跟隨腳本之家小編一起通過本文學(xué)習(xí)吧2018-05-05
PyTorch使用GPU加速計(jì)算的實(shí)現(xiàn)
PyTorch利用NVIDIA CUDA庫提供的底層接口來實(shí)現(xiàn)GPU加速計(jì)算,本文就來介紹一下PyTorch使用GPU加速計(jì)算的實(shí)現(xiàn),具有一定的參考價(jià)值,感興趣的可以了解一下2024-02-02
Python XML轉(zhuǎn)Json之XML2Dict的使用方法
今天小編就為大家分享一篇Python XML轉(zhuǎn)Json之XML2Dict的使用方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01

