欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

python pytorch模型轉(zhuǎn)onnx模型的全過程(多輸入+動態(tài)維度)

 更新時間:2024年03月21日 11:51:25   作者:暗號9  
這篇文章主要介紹了python pytorch模型轉(zhuǎn)onnx模型的全過程(多輸入+動態(tài)維度),本文給大家記錄記錄了pt文件轉(zhuǎn)onnx全過程,簡單的修改即可應(yīng)用,結(jié)合實例代碼給大家介紹的非常詳細(xì),感興趣的朋友一起看看吧

(多輸入+動態(tài)維度)整理的自定義神經(jīng)網(wǎng)絡(luò)pt轉(zhuǎn)onnx過程的python代碼,記錄了pt文件轉(zhuǎn)onnx全過程,簡單的修改即可應(yīng)用。

pt文件轉(zhuǎn)onnx步驟 

1、編寫預(yù)處理代碼

預(yù)處理代碼 與torch模型的預(yù)處理代碼一樣

def preprocess(img):
	img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
	img = np.expand_dims(img, 0)
	sh_im = img.shape
	if sh_im[2]%2==1:
    	img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
	if sh_im[3]%2==1:
    	img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
	img = normalize(img)
	img = torch.Tensor(img)
	return img

2、用onnxruntime導(dǎo)出onnx

def export_onnx(net, model_path, img, nsigma, onnx_outPath):
	nsigma /= 255.
	if torch.cuda.is_available():
    	state_dict = torch.load(model_path)
    	model = net.cuda()
    	dtype = torch.cuda.FloatTensor
    else:
    	state_dict = torch.load(model_path, map_location='cpu')
    	state_dict = remove_dataparallel_wrapper(state_dict)
    	model = net
    	dtype = torch.FloatTensor
	img = Variable(img.type(dtype))
	nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype))
	# 我這里預(yù)訓(xùn)練權(quán)重中參數(shù)名字與網(wǎng)絡(luò)名字不同
	# 相同的話可直接load_state_dict(state_dict)
	new_state_dict = {}
	for k, v in state_dict.items():
    	new_state_dict[k[7:]] = v
	model.load_state_dict(new_state_dict)
	# 設(shè)置onnx的輸入輸出列表,多輸入多輸出就設(shè)置多個
	input_list = ['input', 'nsigma']
	output_list = ['output']
	# onnx模型導(dǎo)出
	# dynamic_axes為動態(tài)維度,如果自己的輸入輸出是維度變化的建議設(shè)置,否則只能輸入固定維度的tensor
	torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True,
		 				input_names=input_list, output_names=output_list,
		 				dynamic_axes={'input_img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
		 				'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}}) 

導(dǎo)出結(jié)果

3、對導(dǎo)出的模型進(jìn)行檢查

此處為檢查onnx模型節(jié)點,后面如果onnx算子不支持轉(zhuǎn)engine時,方便定位節(jié)點,找到不支持的算子進(jìn)行修改

def check_onnx(onnx_model_path):
	model = onnx.load(onnx_model_path)
	onnx.checker.check_model((model))
	print(onnx.helper.printable_graph(model.graph))

下面貼出輸出結(jié)果


netron可視化

4、推理onnx模型,查看輸出是否一致

    def run_onnx(onnx_model_path, test_img, nsigma):
		nsigma /= 255.
		with torch.no_grad:
    	# 這里默認(rèn)是cuda推理torch.cuda.FloatTensor
    	img = Variable(test_img.type(torch.cuda.FloatTensor))
    	nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor))
    	# 設(shè)置GPU推理
    	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    	providers = ['CUDAExecutionProvider'] if device != "cpu" else ['CPUExecutionProvider']
    	# 通過創(chuàng)建onnxruntime session來運行onnx模型
    	ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
    	output = ort_session.run(output_names=['output'],
                             	input_feed={'input_img': np.array(img.cpu(), dtype=np.float32),
                                'nsigma':  np.array(nsigma.cpu(), dtype=np.float32)})
		return output

5、對onnx模型的輸出進(jìn)行處理,顯示cv圖像

def postprocess(img, img_noise_estime):
    out = torch.clamp(img-img_noise_estime, 0., 1.)
    outimg = variable_to_cv2_image(out)
    cv2.imshow(outimg)

6、編輯主函數(shù)進(jìn)行測試

def main():
    ##############################
    #
    #        onnx模型導(dǎo)出
    #
    ##############################
    # pt權(quán)重路徑:自己的路徑 + mypt.pt
    model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth"
    # export onnx模型時輸入進(jìn)去數(shù)據(jù),用于onnx記錄網(wǎng)絡(luò)的計算過程
    export_feed_path = "D:/python/ffdnet-pytorch/noisy.png"
    # onnx模型導(dǎo)出的路徑
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    # 實例化自己的網(wǎng)絡(luò)模型并設(shè)置輸入?yún)?shù)
    net = FFDNet(num_input_channels=3)
    nsigma = 25
    # onnx 導(dǎo)出
    img = cv2.imread(export_feed_path)
    input = preprocess(img)
    export_onnx(net, model_path, input, nsigma, onnx_outpath)
    print("export success!")
    ##############################
    #
    #        檢查onnx模型
    #
    ##############################
    check_onnx(onnx_outpath)
    # netron可視化網(wǎng)絡(luò),可視化用節(jié)點記錄的網(wǎng)絡(luò)推理流程
    netron.start(onnx_outpath)
    ##############################
    #
    #        運行onnx模型
    #
    ##############################
    # 此處過程是數(shù)據(jù)預(yù)處理 ---> 調(diào)用run_onnx函數(shù) ---> 對模型輸出后處理
    # 具體代碼就不再重復(fù)了

#完整代碼

import time
import netron
import cv2
import torch
import onnx
import numpy as np
from torch.autograd import Variable
import onnxruntime as ort
from models import FFDNet
from utils import remove_dataparallel_wrapper, normalize, variable_to_cv2_image
# 此處為預(yù)處理代碼 與torch模型的預(yù)處理代碼一樣
def preprocess(img):
    img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    sh_im = img.shape
    if sh_im[2]%2==1:
        img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
    if sh_im[3]%2==1:
        img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
    img = normalize(img)
    img = torch.Tensor(img)
    return img
# 此處為onnx模型導(dǎo)出的代碼,包括torch模型的pt權(quán)重加載,onnx模型的導(dǎo)出
def export_onnx(net, model_path, img, nsigma, onnx_outPath):
    nsigma /= 255.
    if torch.cuda.is_available():
        state_dict = torch.load(model_path)
        model = net.cuda()
        dtype = torch.cuda.FloatTensor
    else:
        state_dict = torch.load(model_path, map_location='cpu')
        state_dict = remove_dataparallel_wrapper(state_dict)
        model = net
        dtype = torch.FloatTensor
    img = Variable(img.type(dtype))
    nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype))
    # 我這里預(yù)訓(xùn)練權(quán)重中參數(shù)名字與網(wǎng)絡(luò)名字不同
    # 相同的話可直接load_state_dict(state_dict)
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v
    model.load_state_dict(new_state_dict)
    # 設(shè)置onnx的輸入輸出列表,多輸入多輸出就設(shè)置多個
    input_list = ['input', 'nsigma']
    output_list = ['output']
    # onnx模型導(dǎo)出
    # dynamic_axes為動態(tài)維度,如果自己的輸入輸出是維度變化的建議設(shè)置,否則只能輸入固定維度的tensor
    torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True,
                      input_names=input_list, output_names=output_list,
                      dynamic_axes={'input_img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
                                    'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}})
# 此處為檢查onnx模型節(jié)點,后面如果onnx算子不支持轉(zhuǎn)engine時,方便定位節(jié)點,找到不支持的算子進(jìn)行修改
def check_onnx(onnx_model_path):
    model = onnx.load(onnx_model_path)
    onnx.checker.check_model((model))
    print(onnx.helper.printable_graph(model.graph))
# 此處為推理onnx模型的代碼,檢查輸出是否跟torch模型相同
def run_onnx(onnx_model_path, test_img, nsigma):
    nsigma /= 255.
    with torch.no_grad:
        # 這里默認(rèn)是cuda推理torch.cuda.FloatTensor
        img = Variable(test_img.type(torch.cuda.FloatTensor))
        nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor))
        # 設(shè)置GPU推理
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        providers = ['CUDAExecutionProvider'] if device != "cpu" else ['CPUExecutionProvider']
        # 通過創(chuàng)建onnxruntime session來運行onnx模型
        ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
        output = ort_session.run(output_names=['output'],
                                 input_feed={'input_img': np.array(img.cpu(), dtype=np.float32),
                                             'nsigma':  np.array(nsigma.cpu(), dtype=np.float32)})
    return output
# 此處是后處理代碼,將onnx模型的輸出處理成可顯示cv圖像
# 與torch模型的后處理一樣
def postprocess(img, img_noise_estime):
    out = torch.clamp(img-img_noise_estime, 0., 1.)
    outimg = variable_to_cv2_image(out)
    cv2.imshow(outimg)
def main():
    ##############################
    #
    #        onnx模型導(dǎo)出
    #
    ##############################
    # pt權(quán)重路徑:自己的路徑 + mypt.pt
    model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth"
    # export onnx模型時輸入進(jìn)去數(shù)據(jù),用于onnx記錄網(wǎng)絡(luò)的計算過程
    export_feed_path = "D:/python/ffdnet-pytorch/noisy.png"
    # onnx模型導(dǎo)出的路徑
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    # 實例化自己的網(wǎng)絡(luò)模型并設(shè)置輸入?yún)?shù)
    net = FFDNet(num_input_channels=3)
    nsigma = 25
    # onnx 導(dǎo)出
    img = cv2.imread(export_feed_path)
    input = preprocess(img)
    export_onnx(net, model_path, input, nsigma, onnx_outpath)
    print("export success!")
    ##############################
    #
    #        檢查onnx模型
    #
    ##############################
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    check_onnx(onnx_outpath)
    # netron可視化網(wǎng)絡(luò),可視化用節(jié)點記錄的網(wǎng)絡(luò)推理流程
    netron.start(onnx_outpath)
    ##############################
    #
    #        運行onnx模型
    #
    ##############################
    # 此處過程是數(shù)據(jù)預(yù)處理 ---> 調(diào)用run_onnx函數(shù) ---> 對模型輸出后處理
    # 具體代碼就不再重復(fù)了
if __name__ == '__main__':
    main()

到此這篇關(guān)于python pytorch模型轉(zhuǎn)onnx模型(多輸入+動態(tài)維度)的文章就介紹到這了,更多相關(guān)python pytorch模型轉(zhuǎn)onnx模型內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python:按行讀入,排序然后輸出的方法

    python:按行讀入,排序然后輸出的方法

    今天小編就為大家分享一篇python:按行讀入,排序然后輸出的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-07-07
  • Python列表刪除重復(fù)元素與圖像相似度判斷及刪除實例代碼

    Python列表刪除重復(fù)元素與圖像相似度判斷及刪除實例代碼

    這篇文章主要給大家介紹了關(guān)于Python列表刪除重復(fù)元素與圖像相似度判斷及刪除的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2021-05-05
  • 詳解django2中關(guān)于時間處理策略

    詳解django2中關(guān)于時間處理策略

    這篇文章主要介紹了詳解django2中關(guān)于時間處理策略,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2019-03-03
  • 如何使用Python程序完成描述性統(tǒng)計分析需求

    如何使用Python程序完成描述性統(tǒng)計分析需求

    這篇文章主要介紹了如何使用Python程序完成描述性統(tǒng)計分析需求,運用制表和分類,圖形以及計算概括性數(shù)據(jù)來描述數(shù)據(jù)特征的各項活動,需要的朋友可以參考下
    2023-03-03
  • Python更改pip鏡像源的方法示例

    Python更改pip鏡像源的方法示例

    這篇文章主要介紹了Python更改pip鏡像源的方法示例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-12-12
  • Python中實現(xiàn)從目錄中過濾出指定文件類型的文件

    Python中實現(xiàn)從目錄中過濾出指定文件類型的文件

    這篇文章主要介紹了Python中實現(xiàn)從目錄中過濾出指定文件類型的文件,本文是一篇學(xué)筆記,實例相對簡單,需要的朋友可以參考下
    2015-02-02
  • 關(guān)于np.meshgrid函數(shù)中的indexing參數(shù)問題

    關(guān)于np.meshgrid函數(shù)中的indexing參數(shù)問題

    Meshgrid函數(shù)在二維與三維空間中用于生成坐標(biāo)網(wǎng)格,便于進(jìn)行圖像處理和空間數(shù)據(jù)分析,二維情況下,默認(rèn)使用笛卡爾坐標(biāo)系,而三維meshgrid則涉及不同的坐標(biāo)軸取法,在三維情況下,可能會出現(xiàn)坐標(biāo)軸排列序混亂
    2024-09-09
  • Python 日期與時間轉(zhuǎn)換的方法

    Python 日期與時間轉(zhuǎn)換的方法

    這篇文章主要介紹了Python 日期與時間轉(zhuǎn)換的方法,文中講解非常細(xì)致,代碼幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下
    2020-08-08
  • Python3爬蟲中識別圖形驗證碼的實例講解

    Python3爬蟲中識別圖形驗證碼的實例講解

    在本篇內(nèi)容里小編給大家分享的是關(guān)于Python3爬蟲中識別圖形驗證碼的實例講解內(nèi)容,需要的朋友們可以學(xué)習(xí)參考下。
    2020-07-07
  • Python3.7安裝keras和TensorFlow的教程圖解

    Python3.7安裝keras和TensorFlow的教程圖解

    這篇文章主要介紹了Python3.7安裝keras和TensorFlow經(jīng)驗,本文圖文并茂給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-10-10

最新評論