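PyTorch model deployment: how to convert .pth to ONNX
Why convert PyTorch models to ONNX
Converting to ONNX is usually only a means to an end. Once you have the ONNX model, you normally convert it again, for example into TensorRT for deployment. Some people add one more step: ONNX to Caffe first, then Caffe to TensorRT. Note that the ONNX produced by PyTorch's built-in torch.onnx.export, the ONNX that ONNX Runtime accepts, and the ONNX that TensorRT accepts are not necessarily the same. Each consumer supports its own set of operators and opset versions, so a model that exports cleanly may still need adjustments before a particular runtime can load it.
The goal here is to convert the .pth file saved after PyTorch training into an .onnx file, in preparation for deploying the model later. In all three examples below the .pth file holds a state_dict (the weights only), so the network is first rebuilt from the author's own utils_net module and the weights are then loaded into it.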
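Because the ONNX graph a runtime accepts can differ from what torch.onnx.export produces, it is worth checking numerically that the exported model reproduces the PyTorch output. The following is a minimal sketch, assuming NumPy is available; the helper name `compare_outputs` and its default tolerances are illustrative choices, not part of the original scripts. The commented lines show how the two arrays might be obtained, assuming `onnxruntime` is installed and the model was exported with the input name `'input'`, as in the classification example.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-5):
    """Return (max_abs_diff, ok) for two array-likes of the same shape.

    `reference` would typically be the PyTorch output and `candidate`
    the ONNX Runtime output for the same input tensor.
    """
    ref = np.asarray(reference, dtype=np.float32)
    cand = np.asarray(candidate, dtype=np.float32)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    # Largest element-wise deviation (0.0 for empty arrays)
    max_abs_diff = float(np.max(np.abs(ref - cand))) if ref.size else 0.0
    # Element-wise check: |cand - ref| <= atol + rtol * |ref|
    ok = bool(np.allclose(cand, ref, rtol=rtol, atol=atol))
    return max_abs_diff, ok


# Hypothetical usage (not executed here; needs torch, onnxruntime and an exported model):
#   import onnxruntime as ort
#   sess = ort.InferenceSession("classify_model.onnx", providers=["CPUExecutionProvider"])
#   x = torch.rand(1, 1, 416, 416)
#   with torch.no_grad():
#       torch_out = model(x).numpy()
#   ort_out = sess.run(None, {"input": x.numpy()})[0]
#   print(compare_outputs(torch_out, ort_out))
```

Small differences in the last decimal places are normal between runtimes; a large `max_abs_diff` usually points to an operator that was exported or interpreted differently.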
1. Classification model
```python
import argparse

import torch

from utils_net import Resnet  # the author's own model definition

parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='classify_model.pth')
parser.add_argument("--save_onnx_path", default='classify_model.onnx')
# type=int so that values passed on the command line are not left as strings
parser.add_argument("--input_width", type=int, default=416)
parser.add_argument("--input_height", type=int, default=416)
parser.add_argument("--input_channel", type=int, default=1)
parser.add_argument("--num_classes", type=int, default=6)
args = parser.parse_args()


def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_chal, num_cls):
    if not onnx_path.endswith('.onnx'):
        print("Warning! The onnx model name is not correct, "
              "please give a name that ends with '.onnx'!")
        return
    # Rebuild the network, then load the trained weights (state_dict) into it;
    # map_location='cpu' lets a checkpoint saved on a GPU load on any machine
    model = Resnet(num_classes=num_cls)
    model.load_state_dict(torch.load(pth_path, map_location='cpu'))
    model.eval()  # inference behaviour for BatchNorm/Dropout before exporting
    print(f'{pth_path} model loaded')

    input_names = ['input']
    output_names = ['output']
    # Dummy input with the shape the deployed model will receive: N x C x H x W
    im = torch.rand(1, in_chal, in_hig, in_wid)
    torch.onnx.export(model, im, onnx_path, verbose=False,
                      input_names=input_names, output_names=output_names)
    print("Exporting .pth model to onnx model has been successful!")
    print(f"Onnx model save as {onnx_path}")


if __name__ == '__main__':
    pth_to_onnx(pth_path=args.pth_path, onnx_path=args.save_onnx_path,
                in_hig=args.input_height, in_wid=args.input_width,
                in_chal=args.input_channel, num_cls=args.num_classes)
```
Output:
classify_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as classify_model.onnx
Process finished with exit code 0
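A pitfall with argparse options that set only a `default`: the defaults keep their Python type, but any value supplied on the command line arrives as a string. Running an export script with, say, `--input_width 640` would then hand the string `'640'` to `torch.rand`, which fails. The standard-library sketch below (the helper `build_parser` is illustrative, mirroring a few of the script's options) shows the difference `type=int` makes.

```python
import argparse


def build_parser(with_types: bool) -> argparse.ArgumentParser:
    """Mirror a few of the export script's options, with or without type=int."""
    parser = argparse.ArgumentParser()
    kwargs = {"type": int} if with_types else {}
    parser.add_argument("--input_width", default=416, **kwargs)
    parser.add_argument("--input_height", default=416, **kwargs)
    parser.add_argument("--input_channel", default=1, **kwargs)
    parser.add_argument("--num_classes", default=6, **kwargs)
    return parser


# Defaults keep their Python type either way...
untyped_defaults = build_parser(False).parse_args([])
# ...but a value given on the command line stays a str unless type=int is set.
untyped = build_parser(False).parse_args(["--input_width", "640"])
typed = build_parser(True).parse_args(["--input_width", "640"])

print(type(untyped_defaults.input_width).__name__)  # int
print(type(untyped.input_width).__name__)           # str
print(type(typed.input_width).__name__)             # int
```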
2. Segmentation model
```python
import argparse

import torch

from utils_net import seg_net  # the author's own model definition

parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='segment_model.pth')
parser.add_argument("--save_onnx_path", default='segment_model.onnx')
# type=int so that values passed on the command line are not left as strings
parser.add_argument("--input_width", type=int, default=416)
parser.add_argument("--input_height", type=int, default=416)
parser.add_argument("--input_channel", type=int, default=1)
parser.add_argument("--num_classes", type=int, default=4)
args = parser.parse_args()


def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_channel, num_cls):
    if not onnx_path.endswith('.onnx'):
        print("Warning! The onnx model name is not correct, "
              "please give a name that ends with '.onnx'!")
        return
    # Rebuild the network, then load the trained weights (state_dict) into it
    model = seg_net(in_channel=in_channel, num_cls=num_cls)
    model.load_state_dict(torch.load(pth_path, map_location='cpu'))
    model.eval()
    print(f'{pth_path} model loaded')

    input_names = ['input']
    output_names = ['output']
    im = torch.rand(1, in_channel, in_hig, in_wid)  # dummy N x C x H x W input
    # Unlike the classification script, the ONNX opset version is pinned to 11
    torch.onnx.export(model, im, onnx_path, verbose=False,
                      input_names=input_names, output_names=output_names,
                      opset_version=11)
    print("Exporting .pth model to onnx model has been successful!")
    print(f"Onnx model save as {onnx_path}")


if __name__ == '__main__':
    pth_to_onnx(pth_path=args.pth_path, onnx_path=args.save_onnx_path,
                in_hig=args.input_height, in_wid=args.input_width,
                in_channel=args.input_channel, num_cls=args.num_classes)
```
Output:
segment_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as segment_model.onnx
Process finished with exit code 0
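The exports in this article use a fixed 1 × C × 416 × 416 dummy input, so the resulting graphs have static shapes (the detection script even passes `dynamic_axes=None` explicitly). If the deployed model should accept other batch sizes or, for a fully convolutional segmentation network, other resolutions, `torch.onnx.export` takes a `dynamic_axes` argument that maps each input or output name to `{axis index: symbolic name}`. The helper `make_dynamic_axes` below is hypothetical. It assumes NCHW inputs, and `spatial=True` also assumes NCHW outputs (as segmentation logits usually are), so it would be wrong for a classifier's `(N, num_classes)` output.

```python
from typing import Dict, Iterable


def make_dynamic_axes(input_names: Iterable[str],
                      output_names: Iterable[str],
                      spatial: bool = False) -> Dict[str, Dict[int, str]]:
    """Build a dynamic_axes mapping for torch.onnx.export.

    Axis 0 (batch) is always marked dynamic; with spatial=True, axes 2 and 3
    (height and width of an NCHW tensor) are marked dynamic as well.
    """
    axes = {0: "batch"}
    if spatial:
        axes.update({2: "height", 3: "width"})
    # Give every tensor its own copy of the axis mapping
    return {name: dict(axes) for name in list(input_names) + list(output_names)}


# Would be passed as, e.g.:
#   torch.onnx.export(model, im, onnx_path, input_names=["input"], output_names=["output"],
#                     dynamic_axes=make_dynamic_axes(["input"], ["output"], spatial=True))
print(make_dynamic_axes(["input"], ["output"]))
# {'input': {0: 'batch'}, 'output': {0: 'batch'}}
```

Whether the network really works at other resolutions depends on its architecture, and downstream engines such as TensorRT typically need extra configuration for dynamic shapes.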
3. Object detection model
```python
import argparse

import onnx
import torch

from utils_net import YoloBody  # the author's own model definition

parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='yolo.pth')
parser.add_argument("--save_onnx_path", default='yolo.onnx')
parser.add_argument("--input_width", type=int, default=416)
parser.add_argument("--input_height", type=int, default=416)
parser.add_argument("--num_classes", type=int, default=2)
# Anchor-index groups, one per detection head (used as a default, not parsed from the CLI)
parser.add_argument("--anchors_mask", default=[[6, 7, 8], [3, 4, 5], [0, 1, 2]])
args = parser.parse_args()


def pth_to_onnx(pth_path: str, save_onnx_path: str, num_cls: int, in_hig: int, in_wid: int,
                anchor_mask: list, opset_version: int = 12, simplify: bool = False):
    """
    :param pth_path: path to the .pth weights file
    :param save_onnx_path: path where the ONNX model will be saved
    :param num_cls: number of object classes to detect
    :param in_hig: network input height
    :param in_wid: network input width
    :param anchor_mask: indices of the anchor widths/heights used by each head
    :param opset_version: ONNX operator-set version
    :param simplify: whether to simplify the model with onnx-simplifier
    :return: None; the ONNX model is saved to save_onnx_path
    """
    # Build model, load weights (mapped to CPU so a GPU-saved checkpoint loads anywhere)
    net = YoloBody(anchors_mask=anchor_mask, num_classes=num_cls)
    net.load_state_dict(torch.load(pth_path, map_location='cpu'))
    net = net.eval()
    print(f'{pth_path} model loaded')

    im = torch.zeros(1, 3, in_hig, in_wid)  # dummy 3-channel input, N x C x H x W
    input_layer_names = ['images']
    output_layer_names = ['output']

    # Export the model
    print(f'Starting export with onnx {onnx.__version__}.')
    torch.onnx.export(net, im, f=save_onnx_path, verbose=False,
                      opset_version=opset_version,
                      training=torch.onnx.TrainingMode.EVAL,
                      do_constant_folding=True,
                      input_names=input_layer_names,
                      output_names=output_layer_names,
                      dynamic_axes=None)

    # Checks
    model_onnx = onnx.load(save_onnx_path)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    # Simplify onnx
    if simplify:
        import onnxsim
        print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
        # The default arguments keep the input shapes static
        model_onnx, check = onnxsim.simplify(model_onnx)
        assert check, 'assert check failed'
        onnx.save(model_onnx, save_onnx_path)

    print('Onnx model save as {}'.format(save_onnx_path))


if __name__ == '__main__':
    pth_to_onnx(pth_path=args.pth_path, save_onnx_path=args.save_onnx_path,
                num_cls=args.num_classes, in_hig=args.input_height,
                in_wid=args.input_width, anchor_mask=args.anchors_mask)
```
Output:
yolo.pth model loaded
Starting export with onnx 1.11.0.
Onnx model save as yolo.onnx
Process finished with exit code 0
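Before wiring the exported detector into a runtime, it helps to know which output shapes to expect. The sketch below assumes that `YoloBody` follows the common YOLOv3/v4 layout: three heads at strides 32, 16 and 8, paired in order with the three `anchors_mask` groups, each predicting `len(mask) * (5 + num_classes)` channels per grid cell. This is an assumption, since `utils_net` is not shown; the helper `yolo_head_shapes` is hypothetical.

```python
from typing import List, Tuple


def yolo_head_shapes(in_hig: int, in_wid: int, num_classes: int,
                     anchors_mask: List[List[int]],
                     strides: Tuple[int, ...] = (32, 16, 8)) -> List[Tuple[int, int, int, int]]:
    """Expected NCHW output shapes of a YOLOv3/v4-style head for batch size 1.

    Per anchor: 4 box offsets + 1 objectness score + num_classes class scores.
    """
    if len(anchors_mask) != len(strides):
        raise ValueError("need one stride per anchor group")
    shapes = []
    for mask, stride in zip(anchors_mask, strides):
        if in_hig % stride or in_wid % stride:
            raise ValueError(f"input size must be divisible by stride {stride}")
        channels = len(mask) * (5 + num_classes)
        shapes.append((1, channels, in_hig // stride, in_wid // stride))
    return shapes


print(yolo_head_shapes(416, 416, 2, [[6, 7, 8], [3, 4, 5], [0, 1, 2]]))
# [(1, 21, 13, 13), (1, 21, 26, 26), (1, 21, 52, 52)]
```

If `YoloBody` does return three tensors, note that the export script passes a single output name. `torch.onnx.export` applies the names in order, so only the first output would be called `'output'` and the other two would keep auto-generated names; listing one name per output (for example `'output0'`, `'output1'`, `'output2'`, purely illustrative) makes all three easy to fetch by name.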