欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

使用Pytorch導(dǎo)出自定義ONNX算子的示例代碼

 更新時間:2024年03月08日 12:00:30   作者:太陽花的小綠豆  
這篇文章主要介紹了使用Pytorch導(dǎo)出自定義ONNX算子的示例代碼,下面給出個具體應(yīng)用中的示例:需要導(dǎo)出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又無法正常導(dǎo)出該算子,故可通過如下自定義算子代碼導(dǎo)出,需要的朋友可以參考下

在實際部署模型時有時可能會遇到想用的算子無法導(dǎo)出onnx,但實際部署的框架是支持該算子的。此時可以通過自定義onnx算子的方式導(dǎo)出onnx模型(注:自定義onnx算子導(dǎo)出onnx模型后是無法使用onnxruntime推理的)。下面給出個具體應(yīng)用中的示例:需要導(dǎo)出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又無法正常導(dǎo)出該算子,故可通過如下自定義算子代碼導(dǎo)出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes
class CustomAffineGrid(Function):
    @staticmethod
    def forward(ctx, theta: torch.Tensor, size: torch.Tensor):
        grid = F.affine_grid(theta=theta, size=size.cpu().tolist())
        return grid
    @staticmethod
    def symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):
        return g.op("AffineGrid", theta, size)
class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):
        grid = CustomAffineGrid.apply(theta, size)
        x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")
        return x
def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)
        theta = torch.randn(1, 2, 3)
        size = torch.as_tensor([1, 3, 512, 512])
        torch.onnx.export(model=custum_model,
                          args=(x, theta, size),
                          f="custom.onnx",
                          input_names=["input0_x", "input1_theta", "input2_size"],
                          output_names=["output"],
                          dynamic_axes={"input0_x": {2: "h0", 3: "w0"},
                                        "output": {2: "h1", 3: "w1"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)
if __name__ == '__main__':
    main()

在上面代碼中,通過繼承torch.autograd.Function父類的方式實現(xiàn)導(dǎo)出自定義算子,繼承該父類后需要用戶自己實現(xiàn)forward以及symbolic兩個靜態(tài)方法,其中forward方法是在pytorch正常推理時調(diào)用的函數(shù),而symbolic方法是在導(dǎo)出onnx時調(diào)用的函數(shù)。對于forward方法需要按照正常的pytorch語法來實現(xiàn),其中第一個參數(shù)必須是ctx但對于當(dāng)前導(dǎo)出onnx場景可以不用管它,后面的參數(shù)是實際自己傳入的參數(shù)。對于symbolic方法的第一個必須是g,后面的參數(shù)任為實際自己傳入的參數(shù),然后通過g.op方法指定具體導(dǎo)出自定義算子的名稱,以及輸入的參數(shù)(注:上面示例中傳入的都是Tensor所以可以直接傳入,對與非Tensor的參數(shù)可見下面一個示例)。最后在使用時直接調(diào)用自己實現(xiàn)類的apply方法即可。使用netron打開自己導(dǎo)出的onnx文件,可以看到如下所示網(wǎng)絡(luò)結(jié)構(gòu)。

有時按照使用的推理框架導(dǎo)出自定義算子時還需要設(shè)置一些參數(shù)(非Tensor)那么可以參考如下示例,例如要導(dǎo)出int型的參數(shù)k那么可以通過傳入k_i來指定,要導(dǎo)出float型的參數(shù)scale那么可以通過傳入scale_f來指定,要導(dǎo)出string型的參數(shù)clockwise那么可以通過傳入clockwise_s來指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypes
class CustomRot90AndScale(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90
        x *= 1.2
        return x
    @staticmethod
    def symbolic(g: torch.Graph, x: torch.Tensor):
        return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")
class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    def forward(self, x: torch.Tensor):
        return CustomRot90AndScale.apply(x)
def main():
    with torch.inference_mode():
        custum_model = MyModel()
        x = torch.randn(1, 3, 224, 224)
        torch.onnx.export(model=custum_model,
                          args=(x,),
                          f="custom_rot90.onnx",
                          input_names=["input"],
                          output_names=["output"],
                          dynamic_axes={"input": {2: "h0", 3: "w0"},
                                        "output": {2: "w0", 3: "h0"}},
                          opset_version=16,
                          operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)
if __name__ == '__main__':
    main()

使用netron打開自己導(dǎo)出的onnx文件,可以看到如下所示信息。

到此這篇關(guān)于使用Pytorch導(dǎo)出自定義ONNX算子的文章就介紹到這了,更多相關(guān)使用Pytorch導(dǎo)出自定義ONNX算子內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python方差檢驗的含義及用法

    python方差檢驗的含義及用法

    在本篇內(nèi)容里小編給大家整理的是一篇關(guān)于python方差檢驗的含義及用法,有需要的朋友們可以跟著學(xué)習(xí)參考下。
    2021-07-07
  • 使用python裝飾器驗證配置文件示例

    使用python裝飾器驗證配置文件示例

    項目中用到了一個WriteData的函數(shù)保存用戶填寫的配置,為了實現(xiàn)驗證用戶輸入的需求,在不影響接口的使用的前提下,采用了python的裝飾器實現(xiàn),代碼片段演示了如何驗證WriteData函數(shù)的輸入?yún)?shù)
    2014-02-02
  • 對PyQt5中樹結(jié)構(gòu)的實現(xiàn)方法詳解

    對PyQt5中樹結(jié)構(gòu)的實現(xiàn)方法詳解

    今天小編就為大家分享一篇對PyQt5中樹結(jié)構(gòu)的實現(xiàn)方法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-06-06
  • Django的URLconf中使用缺省視圖參數(shù)的方法

    Django的URLconf中使用缺省視圖參數(shù)的方法

    這篇文章主要介紹了Django的URLconf中使用缺省視圖參數(shù)的方法,Django是最著名的Python的web開發(fā)框架,需要的朋友可以參考下
    2015-07-07
  • Python如何統(tǒng)計函數(shù)調(diào)用的耗時

    Python如何統(tǒng)計函數(shù)調(diào)用的耗時

    這篇文章主要為大家詳細(xì)介紹了如何使用Python實現(xiàn)統(tǒng)計函數(shù)調(diào)用的耗時,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下
    2024-04-04
  • Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情

    Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情

    這篇文章主要介紹了Python 調(diào)用 C++ 傳遞numpy 數(shù)據(jù)詳情,文章主要分為兩部分,c++代碼和python代碼,代碼分享詳細(xì),需要的小伙伴可以參考一下,希望對你有所幫助
    2022-03-03
  • PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決

    PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決

    這篇文章主要介紹了PyCharm中Matplotlib繪圖不能顯示UI效果的問題解決,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-03-03
  • 解讀殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect)

    解讀殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect)

    這篇文章主要介紹了殘差網(wǎng)絡(luò)(Residual Network),殘差連接(skip-connect),具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2023-08-08
  • nonebot插件之chatgpt使用詳解

    nonebot插件之chatgpt使用詳解

    這篇文章主要為大家介紹了nonebot插件之chatgpt使用詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-03-03
  • Python實現(xiàn)的求解最小公倍數(shù)算法示例

    Python實現(xiàn)的求解最小公倍數(shù)算法示例

    這篇文章主要介紹了Python實現(xiàn)的求解最小公倍數(shù)算法,涉及Python數(shù)值運(yùn)算、判斷等相關(guān)操作技巧,需要的朋友可以參考下
    2018-05-05

最新評論