python pytorch模型轉(zhuǎn)onnx模型的全過程(多輸入+動態(tài)維度)
(多輸入+動態(tài)維度)整理的自定義神經(jīng)網(wǎng)絡pt轉(zhuǎn)onnx過程的python代碼,記錄了pt文件轉(zhuǎn)onnx全過程,簡單的修改即可應用。
pt文件轉(zhuǎn)onnx步驟
1、編寫預處理代碼
預處理代碼 與torch模型的預處理代碼一樣
def preprocess(img): img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) img = np.expand_dims(img, 0) sh_im = img.shape if sh_im[2]%2==1: img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2) if sh_im[3]%2==1: img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3) img = normalize(img) img = torch.Tensor(img) return img
2、用onnxruntime導出onnx
def export_onnx(net, model_path, img, nsigma, onnx_outPath): nsigma /= 255. if torch.cuda.is_available(): state_dict = torch.load(model_path) model = net.cuda() dtype = torch.cuda.FloatTensor else: state_dict = torch.load(model_path, map_location='cpu') state_dict = remove_dataparallel_wrapper(state_dict) model = net dtype = torch.FloatTensor img = Variable(img.type(dtype)) nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype)) # 我這里預訓練權(quán)重中參數(shù)名字與網(wǎng)絡名字不同 # 相同的話可直接load_state_dict(state_dict) new_state_dict = {} for k, v in state_dict.items(): new_state_dict[k[7:]] = v model.load_state_dict(new_state_dict) # 設置onnx的輸入輸出列表,多輸入多輸出就設置多個 input_list = ['input', 'nsigma'] output_list = ['output'] # onnx模型導出 # dynamic_axes為動態(tài)維度,如果自己的輸入輸出是維度變化的建議設置,否則只能輸入固定維度的tensor torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True, input_names=input_list, output_names=output_list, dynamic_axes={'input_img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}, 'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}})
導出結(jié)果
3、對導出的模型進行檢查
此處為檢查onnx模型節(jié)點,后面如果onnx算子不支持轉(zhuǎn)engine時,方便定位節(jié)點,找到不支持的算子進行修改
def check_onnx(onnx_model_path): model = onnx.load(onnx_model_path) onnx.checker.check_model((model)) print(onnx.helper.printable_graph(model.graph))
下面貼出輸出結(jié)果
netron可視化
4、推理onnx模型,查看輸出是否一致
def run_onnx(onnx_model_path, test_img, nsigma): nsigma /= 255. with torch.no_grad: # 這里默認是cuda推理torch.cuda.FloatTensor img = Variable(test_img.type(torch.cuda.FloatTensor)) nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor)) # 設置GPU推理 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") providers = ['CUDAExecutionProvider'] if device != "cpu" else ['CPUExecutionProvider'] # 通過創(chuàng)建onnxruntime session來運行onnx模型 ort_session = ort.InferenceSession(onnx_model_path, providers=providers) output = ort_session.run(output_names=['output'], input_feed={'input_img': np.array(img.cpu(), dtype=np.float32), 'nsigma': np.array(nsigma.cpu(), dtype=np.float32)}) return output
5、對onnx模型的輸出進行處理,顯示cv圖像
def postprocess(img, img_noise_estime): out = torch.clamp(img-img_noise_estime, 0., 1.) outimg = variable_to_cv2_image(out) cv2.imshow(outimg)
6、編輯主函數(shù)進行測試
def main(): ############################## # # onnx模型導出 # ############################## # pt權(quán)重路徑:自己的路徑 + mypt.pt model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth" # export onnx模型時輸入進去數(shù)據(jù),用于onnx記錄網(wǎng)絡的計算過程 export_feed_path = "D:/python/ffdnet-pytorch/noisy.png" # onnx模型導出的路徑 onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx" # 實例化自己的網(wǎng)絡模型并設置輸入?yún)?shù) net = FFDNet(num_input_channels=3) nsigma = 25 # onnx 導出 img = cv2.imread(export_feed_path) input = preprocess(img) export_onnx(net, model_path, input, nsigma, onnx_outpath) print("export success!") ############################## # # 檢查onnx模型 # ############################## check_onnx(onnx_outpath) # netron可視化網(wǎng)絡,可視化用節(jié)點記錄的網(wǎng)絡推理流程 netron.start(onnx_outpath) ############################## # # 運行onnx模型 # ############################## # 此處過程是數(shù)據(jù)預處理 ---> 調(diào)用run_onnx函數(shù) ---> 對模型輸出后處理 # 具體代碼就不再重復了
#完整代碼
import time import netron import cv2 import torch import onnx import numpy as np from torch.autograd import Variable import onnxruntime as ort from models import FFDNet from utils import remove_dataparallel_wrapper, normalize, variable_to_cv2_image # 此處為預處理代碼 與torch模型的預處理代碼一樣 def preprocess(img): img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1) img = np.expand_dims(img, 0) sh_im = img.shape if sh_im[2]%2==1: img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2) if sh_im[3]%2==1: img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3) img = normalize(img) img = torch.Tensor(img) return img # 此處為onnx模型導出的代碼,包括torch模型的pt權(quán)重加載,onnx模型的導出 def export_onnx(net, model_path, img, nsigma, onnx_outPath): nsigma /= 255. if torch.cuda.is_available(): state_dict = torch.load(model_path) model = net.cuda() dtype = torch.cuda.FloatTensor else: state_dict = torch.load(model_path, map_location='cpu') state_dict = remove_dataparallel_wrapper(state_dict) model = net dtype = torch.FloatTensor img = Variable(img.type(dtype)) nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype)) # 我這里預訓練權(quán)重中參數(shù)名字與網(wǎng)絡名字不同 # 相同的話可直接load_state_dict(state_dict) new_state_dict = {} for k, v in state_dict.items(): new_state_dict[k[7:]] = v model.load_state_dict(new_state_dict) # 設置onnx的輸入輸出列表,多輸入多輸出就設置多個 input_list = ['input', 'nsigma'] output_list = ['output'] # onnx模型導出 # dynamic_axes為動態(tài)維度,如果自己的輸入輸出是維度變化的建議設置,否則只能輸入固定維度的tensor torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True, input_names=input_list, output_names=output_list, dynamic_axes={'input_img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}, 'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}}) # 此處為檢查onnx模型節(jié)點,后面如果onnx算子不支持轉(zhuǎn)engine時,方便定位節(jié)點,找到不支持的算子進行修改 def check_onnx(onnx_model_path): model = onnx.load(onnx_model_path) onnx.checker.check_model((model)) print(onnx.helper.printable_graph(model.graph)) # 此處為推理onnx模型的代碼,檢查輸出是否跟torch模型相同 def run_onnx(onnx_model_path, test_img, nsigma): nsigma /= 255. with torch.no_grad: # 這里默認是cuda推理torch.cuda.FloatTensor img = Variable(test_img.type(torch.cuda.FloatTensor)) nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor)) # 設置GPU推理 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") providers = ['CUDAExecutionProvider'] if device != "cpu" else ['CPUExecutionProvider'] # 通過創(chuàng)建onnxruntime session來運行onnx模型 ort_session = ort.InferenceSession(onnx_model_path, providers=providers) output = ort_session.run(output_names=['output'], input_feed={'input_img': np.array(img.cpu(), dtype=np.float32), 'nsigma': np.array(nsigma.cpu(), dtype=np.float32)}) return output # 此處是后處理代碼,將onnx模型的輸出處理成可顯示cv圖像 # 與torch模型的后處理一樣 def postprocess(img, img_noise_estime): out = torch.clamp(img-img_noise_estime, 0., 1.) outimg = variable_to_cv2_image(out) cv2.imshow(outimg) def main(): ############################## # # onnx模型導出 # ############################## # pt權(quán)重路徑:自己的路徑 + mypt.pt model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth" # export onnx模型時輸入進去數(shù)據(jù),用于onnx記錄網(wǎng)絡的計算過程 export_feed_path = "D:/python/ffdnet-pytorch/noisy.png" # onnx模型導出的路徑 onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx" # 實例化自己的網(wǎng)絡模型并設置輸入?yún)?shù) net = FFDNet(num_input_channels=3) nsigma = 25 # onnx 導出 img = cv2.imread(export_feed_path) input = preprocess(img) export_onnx(net, model_path, input, nsigma, onnx_outpath) print("export success!") ############################## # # 檢查onnx模型 # ############################## onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx" check_onnx(onnx_outpath) # netron可視化網(wǎng)絡,可視化用節(jié)點記錄的網(wǎng)絡推理流程 netron.start(onnx_outpath) ############################## # # 運行onnx模型 # ############################## # 此處過程是數(shù)據(jù)預處理 ---> 調(diào)用run_onnx函數(shù) ---> 對模型輸出后處理 # 具體代碼就不再重復了 if __name__ == '__main__': main()
到此這篇關于python pytorch模型轉(zhuǎn)onnx模型(多輸入+動態(tài)維度)的文章就介紹到這了,更多相關python pytorch模型轉(zhuǎn)onnx模型內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python中實現(xiàn)從目錄中過濾出指定文件類型的文件
這篇文章主要介紹了Python中實現(xiàn)從目錄中過濾出指定文件類型的文件,本文是一篇學筆記,實例相對簡單,需要的朋友可以參考下2015-02-02關于np.meshgrid函數(shù)中的indexing參數(shù)問題
Meshgrid函數(shù)在二維與三維空間中用于生成坐標網(wǎng)格,便于進行圖像處理和空間數(shù)據(jù)分析,二維情況下,默認使用笛卡爾坐標系,而三維meshgrid則涉及不同的坐標軸取法,在三維情況下,可能會出現(xiàn)坐標軸排列序混亂2024-09-09Python3.7安裝keras和TensorFlow的教程圖解
這篇文章主要介紹了Python3.7安裝keras和TensorFlow經(jīng)驗,本文圖文并茂給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下2019-10-10