pytorch中的自定義反向傳播,求導實例
pytorch中自定義backward()函數(shù)。在圖像處理過程中,我們有時候會使用自己定義的算法處理圖像,這些算法多是基于numpy或者scipy等包。
那么如何將自定義算法的梯度加入到pytorch的計算圖中,能使用Loss.backward()操作自動求導并優(yōu)化呢。下面的代碼展示了這個功能`
import torch import numpy as np from PIL import Image from torch.autograd import gradcheck class Bicubic(torch.autograd.Function): def basis_function(self, x, a=-1): x_abs = np.abs(x) if x_abs < 1 and x_abs >= 0: y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1 elif x_abs > 1 and x_abs < 2: y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a else: y = 0 return y def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'): # data_in = data_in.detach().numpy() self.grad = np.zeros(data_in.shape,dtype=np.float32) obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2]) data_tmp = data_in.copy() data_obj = np.zeros(shape=obj_shape, dtype=np.float32) data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode) print(data_tmp.shape) for axis0 in range(obj_shape[0]): f_0 = float(axis0) / scale - np.floor(axis0 / scale) int_0 = int(axis0 / scale) + 2 axis0_weight = np.array( [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]]) for axis1 in range(obj_shape[1]): f_1 = float(axis1) / scale - np.floor(axis1 / scale) int_1 = int(axis1 / scale) + 2 axis1_weight = np.array( [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]]) nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32) grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight) for i in range(4): for j in range(4): nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :] for ii in range(data_in.shape[2]): self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j] tmp = np.matmul(axis0_weight, nbr_pixel) data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0] # img = np.transpose(img[0, :, :, :], [1, 2, 0]) return data_obj def forward(self,input): print(type(input)) input_ = input.detach().numpy() output = self.bicubic_interpolate(input_) # return input.new(output) return torch.Tensor(output) def backward(self,grad_output): print(self.grad.shape,grad_output.shape) grad_output.detach().numpy() grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32) for i in range(self.grad.shape[0]): for j in range(self.grad.shape[1]): grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:] grad_input = grad_output_tmp*self.grad print(type(grad_input)) # return grad_output.new(grad_input) return torch.Tensor(grad_input) def bicubic(input): return Bicubic()(input) def main(): hr = Image.open('./baboon/baboon_hr.png').convert('L') hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2)) hr.requires_grad = True lr = bicubic(hr) print(lr.is_leaf) loss=torch.mean(lr) loss.backward() if __name__ =='__main__': main()
要想實現(xiàn)自動求導,必須同時實現(xiàn)forward(),backward()兩個函數(shù)。
1、從代碼中可以看出來,forward()函數(shù)是針對numpy數(shù)據(jù)操作,返回值再重新指定為torch.Tensor類型。因此就有這個問題出現(xiàn)了:forward輸入input被轉換為numpy類型,輸出轉換為tensor類型,那么輸出output的grad_fn參數(shù)是如何指定的呢。調(diào)試發(fā)現(xiàn),當main()中hr的requires_grad被指定為True,即hr被指定為需要求導的葉子節(jié)點。只要Bicubic類繼承自torch.autograd.Function,那么output也就是代碼中的lr的grad_fn就會被指定為<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic這個類。
2、backward()為求導的函數(shù),gard_output是鏈式求導法則的上一級的梯度,grad_input即為我們想要得到的梯度。只需要在輸入指定grad_output,在調(diào)用loss.backward()過程中的某一步會執(zhí)行到Bicubic的backwward()函數(shù)
以上這篇pytorch中的自定義反向傳播,求導實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Flask框架運用WTForms實現(xiàn)用戶注冊的示例詳解
WTForms 是用于web開發(fā)的靈活的表單驗證和呈現(xiàn)庫,它可以與您選擇的任何web框架和模板引擎一起工作,并支持數(shù)據(jù)驗證、CSRF保護、國際化等。本文將運用WTForms實現(xiàn)用戶注冊功能,需要的可以參考一下2022-12-12使用python的pexpect模塊,實現(xiàn)遠程免密登錄的示例
今天小編就為大家分享一篇使用python的pexpect模塊,實現(xiàn)遠程免密登錄的示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-02-02python sklearn中tsne算法降維結果不一致問題的解決方法
最近在做一個文本聚類的分析,在對文本數(shù)據(jù)embedding后,想著看下數(shù)據(jù)的分布,于是用sklearn的TSNE算法來降維embedding后的數(shù)據(jù)結果,當在多次執(zhí)行后,竟發(fā)現(xiàn)TSNE的結果竟然變了,而且每次都不一樣,所以本文就給大家講講如何解決sklearn中tsne算法降維結果不一致的問題2023-10-10Python Pycharm虛擬下百度飛漿PaddleX安裝報錯問題及處理方法(親測100%有效)
最近很多朋友給小編留言在安裝PaddleX的時候總是出現(xiàn)各種奇葩問題,不知道該怎么處理,今天小編通過本文給大家介紹下Python Pycharm虛擬下百度飛漿PaddleX安裝報錯問題及處理方法,真的有效,遇到同樣問題的朋友快來參考下吧2021-05-05Python構造函數(shù)與析構函數(shù)超詳細分析
在python之中定義一個類的時候會在類中創(chuàng)建一個名為__init__的函數(shù),這個函數(shù)就叫做構造函數(shù)。它的作用就是在實例化類的時候去自動的定義一些屬性和方法的值,而析構函數(shù)恰恰是一個和它相反的函數(shù),這篇文章主要介紹了Python構造函數(shù)與析構函數(shù)2022-11-11Python中循環(huán)后使用list.append()數(shù)據(jù)被覆蓋問題的解決
這篇文章主要給大家介紹了關于Python中循環(huán)后使用list.append()數(shù)據(jù)被覆蓋問題的解決方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2018-07-07Pytorch中的 torch.distributions庫詳解
這篇文章主要介紹了Pytorch中的 torch.distributions庫,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-02-02