PyTorch中clone()、detach()及相關擴展詳解
clone() 與 detach() 對比
Torch 為了提高速度,向量或是矩陣的賦值是指向同一內存的,這不同于 Matlab。如果需要保存舊的tensor即需要開辟新的存儲地址而不是引用,可以用 clone() 進行深拷貝,
首先我們來打印出來clone()操作后的數據類型定義變化:
(1). 簡單打印類型
import torch a = torch.tensor(1.0, requires_grad=True) b = a.clone() c = a.detach() a.data *= 3 b += 1 print(a) # tensor(3., requires_grad=True) print(b) print(c) ''' 輸出結果: tensor(3., requires_grad=True) tensor(2., grad_fn=<AddBackward0>) tensor(3.) # detach()后的值隨著a的變化出現(xiàn)變化 '''
grad_fn=<CloneBackward>,表示clone后的返回值是個中間變量,因此支持梯度的回溯。clone操作在一定程度上可以視為是一個identity-mapping函數。
detach()操作后的tensor與原始tensor共享數據內存,當原始tensor在計算圖中數值發(fā)生反向傳播等更新之后,detach()的tensor值也發(fā)生了改變。
注意: 在pytorch中我們不要直接使用id是否相等來判斷tensor是否共享內存,這只是充分條件,因為也許底層共享數據內存,但是仍然是新的tensor,比如detach(),如果我們直接打印id會出現(xiàn)以下情況。
import torch as t a = t.tensor([1.0,2.0], requires_grad=True) b = a.detach() #c[:] = a.detach() print(id(a)) print(id(b)) #140568935450520 140570337203616
顯然直接打印出來的id不等,我們可以通過簡單的賦值后觀察數據變化進行判斷。
(2). clone()的梯度回傳
detach()函數可以返回一個完全相同的tensor,與舊的tensor共享內存,脫離計算圖,不會牽扯梯度計算。
而clone充當中間變量,會將梯度傳給源張量進行疊加,但是本身不保存其grad,即值為None
import torch a = torch.tensor(1.0, requires_grad=True) a_ = a.clone() y = a**2 z = a ** 2+a_ * 3 y.backward() print(a.grad) # 2 z.backward() print(a_.grad) # None. 中間variable,無grad print(a.grad) ''' 輸出: tensor(2.) None tensor(7.) # 2*2+3=7 '''
使用torch.clone()獲得的新tensor和原來的數據不再共享內存,但仍保留在計算圖中,clone操作在不共享數據內存的同時支持梯度梯度傳遞與疊加,所以常用在神經網絡中某個單元需要重復使用的場景下。
通常如果原tensor的requires_grad=True,則:
- clone()操作后的tensor requires_grad=True
- detach()操作后的tensor requires_grad=False。
import torch torch.manual_seed(0) x= torch.tensor([1., 2.], requires_grad=True) clone_x = x.clone() detach_x = x.detach() clone_detach_x = x.clone().detach() f = torch.nn.Linear(2, 1) y = f(x) y.backward() print(x.grad) print(clone_x.requires_grad) print(clone_x.grad) print(detach_x.requires_grad) print(clone_detach_x.requires_grad) ''' 輸出結果如下: tensor([-0.0053, 0.3793]) True None False False '''
另一個比較特殊的是當源張量的 require_grad=False,clone后的張量 require_grad=True,此時不存在張量回傳現(xiàn)象,可以得到clone后的張量求導。
如下:
import torch a = torch.tensor(1.0) a_ = a.clone() a_.requires_grad_() #require_grad=True y = a_ ** 2 y.backward() print(a.grad) # None print(a_.grad) ''' 輸出: None tensor(2.) '''
了解了兩者的區(qū)別后我們常與其他函數進行搭配使用,實現(xiàn)數據拷貝后的其他需要。
比如我們經常使用view()函數對tensor進行reshape操作。返回的新Tensor與源Tensor可能有不同的size,但是是共享data的,即其中的一個發(fā)生變化,另外一個也會跟著改變。
需要注意的是view返回的Tensor與源Tensor是共享data的,但是依然是一個新的Tensor(因為Tensor除了包含data外還有一些其他屬性),兩者id(內存地址)并不一致。
x = torch.rand(2, 2) y = x.view(4) x += 1 print(x) print(y) # 也加了1
view() 僅僅是改變了對這個張量的觀察角度,內部數據并未改變。這時候想返回一個真正新的副本(即不共享data內存)該怎么辦呢?Pytorch還提供了一個reshape()可以改變形狀,但是此函數并不能保證返回的是其拷貝,所以不推薦使用。推薦先用clone創(chuàng)造一個副本然后再使用view。參考此處
x = torch.rand(2, 2) x_cp = x.clone().view(4) x += 1 print(id(x)) print(id(x_cp)) print(x) print(x_cp) ''' 140568935036464 140568935035816 tensor([[0.4963, 0.7682], [0.1320, 0.3074]]) tensor([[1.4963, 1.7682, 1.1320, 1.3074]]) '''
另外使用clone()會被記錄在計算圖中,即梯度回傳到副本時也會傳到源Tensor。在上一篇中有總結。
總結:
- torch.detach() — 新的tensor會脫離計算圖,不會牽扯梯度計算
- torch.clone() — 新的tensor充當中間變量,會保留在計算圖中,參與梯度計算(回傳疊加),但是一般不會保留自身梯度。
原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面兩者中執(zhí)行都會引發(fā)錯誤或者警告。 - 共享數據內存是底層設計,并不能簡單的通過直接打印tensor的id地址進行判斷,需要在進行賦值或運算操作后打印比較數據的變化進行判斷。
- 復制操作可以根據實際需要進行結合使用。
引用官方文檔的話:如果你使用了in-place operation而沒有報錯的話,那么你可以確定你的梯度計算是正確的。另外盡量避免in-place的使用。
像y = x + y這樣的運算會新開內存,然后將y指向新內存。我們可以使用Python自帶的id函數進行驗證:如果兩個實例的ID相同,則它們所對應的內存地址相同。
到此這篇關于PyTorch中clone()、detach()及相關擴展詳解的文章就介紹到這了,更多相關PyTorch中clone()、detach()及相關擴展內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python實現(xiàn)讀取大量Excel文件并跨文件批量計算平均值
這篇文章主要為大家詳細介紹了如何利用Python語言,實現(xiàn)對多個不同Excel文件進行數據讀取與平均值計算的方法,感興趣的可以了解一下2023-02-02
對pandas的dataframe繪圖并保存的實現(xiàn)方法
下面小編就為大家?guī)硪黄獙andas的dataframe繪圖并保存的實現(xiàn)方法。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-08-08
Python?tkinter中l(wèi)abel控件動態(tài)改變值問題
這篇文章主要介紹了Python?tkinter中l(wèi)abel控件動態(tài)改變值問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-01-01

