使用Pytorch導(dǎo)出自定義ONNX算子的示例代碼
在實際部署模型時有時可能會遇到想用的算子無法導(dǎo)出onnx,但實際部署的框架是支持該算子的。此時可以通過自定義onnx算子的方式導(dǎo)出onnx模型(注:自定義onnx算子導(dǎo)出onnx模型后是無法使用onnxruntime推理的)。下面給出個具體應(yīng)用中的示例:需要導(dǎo)出pytorch的affine_grid
算子,但在pytorch的2.0.1
版本中又無法正常導(dǎo)出該算子,故可通過如下自定義算子代碼導(dǎo)出。
import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Function from torch.onnx import OperatorExportTypes class CustomAffineGrid(Function): @staticmethod def forward(ctx, theta: torch.Tensor, size: torch.Tensor): grid = F.affine_grid(theta=theta, size=size.cpu().tolist()) return grid @staticmethod def symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor): return g.op("AffineGrid", theta, size) class MyModel(nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor): grid = CustomAffineGrid.apply(theta, size) x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros") return x def main(): with torch.inference_mode(): custum_model = MyModel() x = torch.randn(1, 3, 224, 224) theta = torch.randn(1, 2, 3) size = torch.as_tensor([1, 3, 512, 512]) torch.onnx.export(model=custum_model, args=(x, theta, size), f="custom.onnx", input_names=["input0_x", "input1_theta", "input2_size"], output_names=["output"], dynamic_axes={"input0_x": {2: "h0", 3: "w0"}, "output": {2: "h1", 3: "w1"}}, opset_version=16, operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH) if __name__ == '__main__': main()
在上面代碼中,通過繼承torch.autograd.Function
父類的方式實現(xiàn)導(dǎo)出自定義算子,繼承該父類后需要用戶自己實現(xiàn)forward
以及symbolic
兩個靜態(tài)方法,其中forward
方法是在pytorch正常推理時調(diào)用的函數(shù),而symbolic
方法是在導(dǎo)出onnx時調(diào)用的函數(shù)。對于forward
方法需要按照正常的pytorch語法來實現(xiàn),其中第一個參數(shù)必須是ctx
但對于當(dāng)前導(dǎo)出onnx場景可以不用管它,后面的參數(shù)是實際自己傳入的參數(shù)。對于symbolic
方法的第一個必須是g
,后面的參數(shù)任為實際自己傳入的參數(shù),然后通過g.op
方法指定具體導(dǎo)出自定義算子的名稱,以及輸入的參數(shù)(注:上面示例中傳入的都是Tensor
所以可以直接傳入,對與非Tensor
的參數(shù)可見下面一個示例)。最后在使用時直接調(diào)用自己實現(xiàn)類的apply
方法即可。使用netron
打開自己導(dǎo)出的onnx文件,可以看到如下所示網(wǎng)絡(luò)結(jié)構(gòu)。
有時按照使用的推理框架導(dǎo)出自定義算子時還需要設(shè)置一些參數(shù)(非Tensor
)那么可以參考如下示例,例如要導(dǎo)出int
型的參數(shù)k
那么可以通過傳入k_i
來指定,要導(dǎo)出float
型的參數(shù)scale
那么可以通過傳入scale_f
來指定,要導(dǎo)出string
型的參數(shù)clockwise
那么可以通過傳入clockwise_s
來指定:
import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Function from torch.onnx import OperatorExportTypes class CustomRot90AndScale(Function): @staticmethod def forward(ctx, x: torch.Tensor): x = torch.rot90(x, k=1, dims=(3, 2)) # clockwise 90 x *= 1.2 return x @staticmethod def symbolic(g: torch.Graph, x: torch.Tensor): return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes") class MyModel(nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x: torch.Tensor): return CustomRot90AndScale.apply(x) def main(): with torch.inference_mode(): custum_model = MyModel() x = torch.randn(1, 3, 224, 224) torch.onnx.export(model=custum_model, args=(x,), f="custom_rot90.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {2: "h0", 3: "w0"}, "output": {2: "w0", 3: "h0"}}, opset_version=16, operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH) if __name__ == '__main__': main()
使用netron
打開自己導(dǎo)出的onnx文件,可以看到如下所示信息。
到此這篇關(guān)于使用Pytorch導(dǎo)出自定義ONNX算子的文章就介紹到這了,更多相關(guān)使用Pytorch導(dǎo)出自定義ONNX算子內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
對PyQt5中樹結(jié)構(gòu)的實現(xiàn)方法詳解
今天小編就為大家分享一篇對PyQt5中樹結(jié)構(gòu)的實現(xiàn)方法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06Django的URLconf中使用缺省視圖參數(shù)的方法
這篇文章主要介紹了Django的URLconf中使用缺省視圖參數(shù)的方法,Django是最著名的Python的web開發(fā)框架,需要的朋友可以參考下2015-07-07Python如何統(tǒng)計函數(shù)調(diào)用的耗時
這篇文章主要為大家詳細(xì)介紹了如何使用Python實現(xiàn)統(tǒng)計函數(shù)調(diào)用的耗時,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-04-04Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情
這篇文章主要介紹了Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情,文章主要分為兩部分,c++代碼和python代碼,代碼分享詳細(xì),需要的小伙伴可以參考一下,希望對你有所幫助2022-03-03PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決
這篇文章主要介紹了PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03解讀殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect)
這篇文章主要介紹了殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect),具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-08-08Python實現(xiàn)的求解最小公倍數(shù)算法示例
這篇文章主要介紹了Python實現(xiàn)的求解最小公倍數(shù)算法,涉及Python數(shù)值運(yùn)算、判斷等相關(guān)操作技巧,需要的朋友可以參考下2018-05-05