使用Pytorch導(dǎo)出自定義ONNX算子的示例代碼
在實(shí)際部署模型時(shí)有時(shí)可能會(huì)遇到想用的算子無(wú)法導(dǎo)出onnx,但實(shí)際部署的框架是支持該算子的。此時(shí)可以通過(guò)自定義onnx算子的方式導(dǎo)出onnx模型(注:自定義onnx算子導(dǎo)出onnx模型后是無(wú)法使用onnxruntime推理的)。下面給出個(gè)具體應(yīng)用中的示例:需要導(dǎo)出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又無(wú)法正常導(dǎo)出該算子,故可通過(guò)如下自定義算子代碼導(dǎo)出。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes
class CustomAffineGrid(Function):
@staticmethod
def forward(ctx, theta: torch.Tensor, size: torch.Tensor):
grid = F.affine_grid(theta=theta, size=size.cpu().tolist())
return grid
@staticmethod
def symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):
return g.op("AffineGrid", theta, size)
class MyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):
grid = CustomAffineGrid.apply(theta, size)
x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")
return x
def main():
with torch.inference_mode():
custum_model = MyModel()
x = torch.randn(1, 3, 224, 224)
theta = torch.randn(1, 2, 3)
size = torch.as_tensor([1, 3, 512, 512])
torch.onnx.export(model=custum_model,
args=(x, theta, size),
f="custom.onnx",
input_names=["input0_x", "input1_theta", "input2_size"],
output_names=["output"],
dynamic_axes={"input0_x": {2: "h0", 3: "w0"},
"output": {2: "h1", 3: "w1"}},
opset_version=16,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)
if __name__ == '__main__':
main()在上面代碼中,通過(guò)繼承torch.autograd.Function父類的方式實(shí)現(xiàn)導(dǎo)出自定義算子,繼承該父類后需要用戶自己實(shí)現(xiàn)forward以及symbolic兩個(gè)靜態(tài)方法,其中forward方法是在pytorch正常推理時(shí)調(diào)用的函數(shù),而symbolic方法是在導(dǎo)出onnx時(shí)調(diào)用的函數(shù)。對(duì)于forward方法需要按照正常的pytorch語(yǔ)法來(lái)實(shí)現(xiàn),其中第一個(gè)參數(shù)必須是ctx但對(duì)于當(dāng)前導(dǎo)出onnx場(chǎng)景可以不用管它,后面的參數(shù)是實(shí)際自己傳入的參數(shù)。對(duì)于symbolic方法的第一個(gè)必須是g,后面的參數(shù)任為實(shí)際自己傳入的參數(shù),然后通過(guò)g.op方法指定具體導(dǎo)出自定義算子的名稱,以及輸入的參數(shù)(注:上面示例中傳入的都是Tensor所以可以直接傳入,對(duì)與非Tensor的參數(shù)可見(jiàn)下面一個(gè)示例)。最后在使用時(shí)直接調(diào)用自己實(shí)現(xiàn)類的apply方法即可。使用netron打開(kāi)自己導(dǎo)出的onnx文件,可以看到如下所示網(wǎng)絡(luò)結(jié)構(gòu)。

有時(shí)按照使用的推理框架導(dǎo)出自定義算子時(shí)還需要設(shè)置一些參數(shù)(非Tensor)那么可以參考如下示例,例如要導(dǎo)出int型的參數(shù)k那么可以通過(guò)傳入k_i來(lái)指定,要導(dǎo)出float型的參數(shù)scale那么可以通過(guò)傳入scale_f來(lái)指定,要導(dǎo)出string型的參數(shù)clockwise那么可以通過(guò)傳入clockwise_s來(lái)指定:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes
class CustomRot90AndScale(Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
x = torch.rot90(x, k=1, dims=(3, 2)) # clockwise 90
x *= 1.2
return x
@staticmethod
def symbolic(g: torch.Graph, x: torch.Tensor):
return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")
class MyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor):
return CustomRot90AndScale.apply(x)
def main():
with torch.inference_mode():
custum_model = MyModel()
x = torch.randn(1, 3, 224, 224)
torch.onnx.export(model=custum_model,
args=(x,),
f="custom_rot90.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {2: "h0", 3: "w0"},
"output": {2: "w0", 3: "h0"}},
opset_version=16,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)
if __name__ == '__main__':
main()使用netron打開(kāi)自己導(dǎo)出的onnx文件,可以看到如下所示信息。

到此這篇關(guān)于使用Pytorch導(dǎo)出自定義ONNX算子的文章就介紹到這了,更多相關(guān)使用Pytorch導(dǎo)出自定義ONNX算子內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
對(duì)PyQt5中樹(shù)結(jié)構(gòu)的實(shí)現(xiàn)方法詳解
今天小編就為大家分享一篇對(duì)PyQt5中樹(shù)結(jié)構(gòu)的實(shí)現(xiàn)方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-06-06
Django的URLconf中使用缺省視圖參數(shù)的方法
這篇文章主要介紹了Django的URLconf中使用缺省視圖參數(shù)的方法,Django是最著名的Python的web開(kāi)發(fā)框架,需要的朋友可以參考下2015-07-07
Python如何統(tǒng)計(jì)函數(shù)調(diào)用的耗時(shí)
這篇文章主要為大家詳細(xì)介紹了如何使用Python實(shí)現(xiàn)統(tǒng)計(jì)函數(shù)調(diào)用的耗時(shí),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-04-04
Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情
這篇文章主要介紹了Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情,文章主要分為兩部分,c++代碼和python代碼,代碼分享詳細(xì),需要的小伙伴可以參考一下,希望對(duì)你有所幫助2022-03-03
PyCharm中Matplotlib繪圖不能顯示UI效果的問(wèn)題解決
這篇文章主要介紹了PyCharm中Matplotlib繪圖不能顯示UI效果的問(wèn)題解決,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03
解讀殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect)
這篇文章主要介紹了殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect),具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-08-08
Python實(shí)現(xiàn)的求解最小公倍數(shù)算法示例
這篇文章主要介紹了Python實(shí)現(xiàn)的求解最小公倍數(shù)算法,涉及Python數(shù)值運(yùn)算、判斷等相關(guān)操作技巧,需要的朋友可以參考下2018-05-05

