pytorch_detach 切斷網(wǎng)絡(luò)反傳方式
detach
官方文檔中,對這個方法是這么介紹的。
detach = _add_docstr(_C._TensorBase.detach, r""" Returns a new Tensor, detached from the current graph. The result will never require gradient. .. note:: Returned Tensor uses the same data tensor as the original one. In-place modifications on either of them will be seen, and may trigger errors in correctness checks. """)
返回一個新的從當前圖中分離的 Variable。
返回的 Variable 永遠不會需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出來的 volatile 也為 True
還有一個注意事項,即:返回的 Variable 和 被 detach 的Variable 指向同一個 tensor
import torch from torch.nn import init t1 = torch.tensor([1., 2.],requires_grad=True) t2 = torch.tensor([2., 3.],requires_grad=True) v3 = t1 + t2 v3_detached = v3.detach() v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值 print(v3, v3_detached) # v3 中tensor 的值也會改變 print(v3.requires_grad,v3_detached.requires_grad) ''' tensor([4., 7.], grad_fn=<AddBackward0>) tensor([4., 7.]) True False '''
在pytorch中通過拷貝需要切斷位置前的tensor實現(xiàn)這個功能。tensor中拷貝的函數(shù)有兩個,一個是clone(),另外一個是copy_(),clone()相當于完全復制了之前的tensor,他的梯度也會復制,而且在反向傳播時,克隆的樣本和結(jié)果是等價的,可以簡單的理解為clone只是給了同一個tensor不同的代號,和‘='等價。所以如果想要生成一個新的分開的tensor,請使用copy_()。
不過對于這樣的操作,pytorch中有專門的函數(shù)——detach()。
用戶自己創(chuàng)建的節(jié)點是leaf_node(如圖中的abc三個節(jié)點),不依賴于其他變量,對于leaf_node不能進行in_place操作.根節(jié)點是計算圖的最終目標(如圖y),通過鏈式法則可以計算出所有節(jié)點相對于根節(jié)點的梯度值.這一過程通過調(diào)用root.backward()就可以實現(xiàn).
因此,detach所做的就是,重新聲明一個變量,指向原變量的存放位置,但是requires_grad為false.更深入一點的理解是,計算圖從detach過的變量這里就斷了, 它變成了一個leaf_node.即使之后重新將它的requires_node置為true,它也不會具有梯度.
pytorch 梯度
(0.4之后),tensor和variable合并,tensor具有g(shù)rad、grad_fn等屬性;
默認創(chuàng)建的tensor,grad默認為False, 如果當前tensor_grad為None,則不會向前傳播,如果有其它支路具有g(shù)rad,則只傳播其它支路的grad
# 默認創(chuàng)建requires_grad = False的Tensor x = torch.ones(1) # create a tensor with requires_grad=False (default) print(x.requires_grad) # out: False # 創(chuàng)建另一個Tensor,同樣requires_grad = False y = torch.ones(1) # another tensor with requires_grad=False # both inputs have requires_grad=False. so does the output z = x + y # 因為兩個Tensor x,y,requires_grad=False.都無法實現(xiàn)自動微分, # 所以操作(operation)z=x+y后的z也是無法自動微分,requires_grad=False print(z.requires_grad) # out: False # then autograd won't track this computation. let's verify! # 因而無法autograd,程序報錯 # z.backward() # out:程序報錯:RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn # now create a tensor with requires_grad=True w = torch.ones(1, requires_grad=True) print(w.requires_grad) # out: True # add to the previous result that has require_grad=False # 因為total的操作中輸入Tensor w的requires_grad=True,因而操作可以進行反向傳播和自動求導。 total = w + z # the total sum now requires grad! total.requires_grad # out: True # autograd can compute the gradients as well total.backward() print(w.grad) #out: tensor([ 1.]) # and no computation is wasted to compute gradients for x, y and z, which don't require grad # 由于z,x,y的requires_grad=False,所以并沒有計算三者的梯度 z.grad == x.grad == y.grad == None # True
nn.Paramter
import torch.nn.functional as F # With square kernels and equal stride filters = torch.randn(8,4,3,3) weiths = torch.nn.Parameter(torch.randn(8,4,3,3)) inputs = torch.randn(1,4,5,5) out = F.conv2d(inputs, weiths, stride=2,padding=1) print(out.shape) con2d = torch.nn.Conv2d(4,8,3,stride=2,padding=1) out_2 = con2d(inputs) print(out_2.shape)
補充:Pytorch-detach()用法
目的:
神經(jīng)網(wǎng)絡(luò)的訓練有時候可能希望保持一部分的網(wǎng)絡(luò)參數(shù)不變,只對其中一部分的參數(shù)進行調(diào)整。
或者訓練部分分支網(wǎng)絡(luò),并不讓其梯度對主網(wǎng)絡(luò)的梯度造成影響.這時候我們就需要使用detach()函數(shù)來切斷一些分支的反向傳播.
1 tensor.detach()
返回一個新的tensor,從當前計算圖中分離下來。但是仍指向原變量的存放位置,不同之處只是requirse_grad為false.得到的這個tensir永遠不需要計算器梯度,不具有g(shù)rad.
即使之后重新將它的requires_grad置為true,它也不會具有梯度grad.這樣我們就會繼續(xù)使用這個新的tensor進行計算,后面當我們進行反向傳播時,到該調(diào)用detach()的tensor就會停止,不能再繼續(xù)向前進行傳播.
注意:
使用detach返回的tensor和原始的tensor共同一個內(nèi)存,即一個修改另一個也會跟著改變。
比如正常的例子是:
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a) print(a.grad) out = a.sigmoid() out.sum().backward() print(a.grad)
輸出
tensor([1., 2., 3.], requires_grad=True)
None
tensor([0.1966, 0.1050, 0.0452])
1.1 當使用detach()分離tensor但是沒有更改這個tensor時,并不會影響backward():
import torch a = torch.tensor([1, 2, 3.], requires_grad=True) print(a.grad) out = a.sigmoid() print(out) #添加detach(),c的requires_grad為False c = out.detach() print(c) #這時候沒有對c進行更改,所以并不會影響backward() out.sum().backward() print(a.grad) '''返回: None tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>) tensor([0.7311, 0.8808, 0.9526]) tensor([0.1966, 0.1050, 0.0452]) '''
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。
- pytorch 禁止/允許計算局部梯度的操作
- 如何利用Pytorch計算三角函數(shù)
- 聊聊PyTorch中eval和no_grad的關(guān)系
- Pytorch實現(xiàn)圖像識別之數(shù)字識別(附詳細注釋)
- Pytorch實現(xiàn)全連接層的操作
- pytorch 優(yōu)化器(optim)不同參數(shù)組,不同學習率設(shè)置的操作
- PyTorch 如何將CIFAR100數(shù)據(jù)按類標歸類保存
- PyTorch的Debug指南
- Python深度學習之使用Pytorch搭建ShuffleNetv2
- win10系統(tǒng)配置GPU版本Pytorch的詳細教程
- 淺談pytorch中的nn.Sequential(*net[3: 5])是啥意思
- pytorch visdom安裝開啟及使用方法
- PyTorch CUDA環(huán)境配置及安裝的步驟(圖文教程)
- pytorch中的nn.ZeroPad2d()零填充函數(shù)實例詳解
- 使用pytorch實現(xiàn)線性回歸
- pytorch實現(xiàn)線性回歸以及多元回歸
- PyTorch學習之軟件準備與基本操作總結(jié)
相關(guān)文章
PyQt5 實現(xiàn)字體大小自適應(yīng)分辨率的方法
今天小編就為大家分享一篇PyQt5 實現(xiàn)字體大小自適應(yīng)分辨率的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06Python中往列表中插入字典時,數(shù)據(jù)重復問題
這篇文章主要介紹了Python中往列表中插入字典時,數(shù)據(jù)重復問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-02-02Python的經(jīng)緯度與xy坐標系相互轉(zhuǎn)換方式
這篇文章主要介紹了Python的經(jīng)緯度與xy坐標系相互轉(zhuǎn)換方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2024-02-02