pytorch模型部署 pth轉(zhuǎn)onnx的方法
Pytorch轉(zhuǎn)ONNX的意義
一般來(lái)說(shuō)轉(zhuǎn)ONNX只是一個(gè)手段,在之后得到ONNX模型后還需要再將它做轉(zhuǎn)換,比如轉(zhuǎn)換到TensorRT上完成部署,或者有的人多加一步,從ONNX先轉(zhuǎn)換到caffe,再?gòu)腸affe到tensorRT。Pytorch自帶的torch.onnx.export轉(zhuǎn)換得到的ONNX,ONNXRuntime需要的ONNX,TensorRT需要的ONNX都是不同的。
將pytorch訓(xùn)練保存的pth文件轉(zhuǎn)為onnx文件,為后續(xù)模型部署做準(zhǔn)備。
一、分類模型
import torch
import os
import timm
import argparse
from utils_net import Resnet
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='classify_model.pth')
parser.add_argument("--save_onnx_path", default='classify_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=6)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_chal, num_cls):
if not onnx_path.endswith('.onnx'):
print('Warning! The onnx model name is not correct,\
please give a name that ends with \'.onnx\'!')
return 0
model = Resnet(num_classes=num_cls)
model.load_state_dict(torch.load(pth_path))
model.eval()
print(f'{pth_path} model loaded')
input_names = ['input']
output_names = ['output']
im = torch.rand(1, in_chal, in_hig, in_wid)
torch.onnx.export(model, im, onnx_path,
verbose=False,
input_names=input_names,
output_names=output_names)
print("Exporting .pth model to onnx model has been successful!")
print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
pth_to_onnx(pth_path=args.pth_path,
onnx_path=args.save_onnx_path,
in_hig=args.input_height,
in_wid=args.input_width,
in_chal=args.input_channel,
num_cls=args.num_classes)運(yùn)行結(jié)果:
classify_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as classify_model.onnxProcess finished with exit code 0
二、分割模型
import torch
import os
import argparse
from utils_net import seg_net
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='segment_model.pth')
parser.add_argument("--save_onnx_path", default='segment_model.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--input_channel", default=1)
parser.add_argument("--num_classes", default=4)
args = parser.parse_args()
def pth_to_onnx(pth_path, onnx_path, in_hig, in_wid, in_channel, num_cls):
if not onnx_path.endswith('.onnx'):
print('Warning! The onnx model name is not correct,\
please give a name that ends with \'.onnx\'!')
return 0
model = seg_net(in_channel=in_channel, num_cls=num_cls)
model.load_state_dict(torch.load(pth_path))
model.eval()
print(f'{pth_path} model loaded')
input_names = ['input']
output_names = ['output']
im = torch.rand(1, in_channel, in_hig, in_wid)
torch.onnx.export(model, im, onnx_path,
verbose=False,
input_names=input_names,
output_names=output_names,
opset_version=11)
print("Exporting .pth model to onnx model has been successful!")
print(f"Onnx model save as {onnx_path}")
if __name__ == '__main__':
pth_to_onnx(pth_path=args.pth_path,
onnx_path=args.save_onnx_path,
in_hig=args.input_height,
in_wid=args.input_width,
in_channel=args.input_channel,
num_cls=args.num_classes)運(yùn)行結(jié)果:
segment_model.pth model loaded
Exporting .pth model to onnx model has been successful!
Onnx model save as segment_model.onnxProcess finished with exit code 0
三、目標(biāo)檢測(cè)模型
在這里插入代碼片
import torch
import onnx
import argparse
from utils_net import YoloBody
parser = argparse.ArgumentParser()
parser.add_argument("--pth_path", default='yolo.pth')
parser.add_argument("--save_onnx_path", default='yolo.onnx')
parser.add_argument("--input_width", default=416)
parser.add_argument("--input_height", default=416)
parser.add_argument("--num_classes", default=2)
parser.add_argument("--anchors_mask", default=[[6, 7, 8], [3, 4, 5], [0, 1, 2]])
args = parser.parse_args()
def pth_to_onnx(pth_path: str, save_onnx_path: str, num_cls: int,
in_hig: int, in_wid: int, anchor_mask: list,
opset_version: int = 12, simplify: bool = False):
"""
:param pth_path: pth文件文件
:param save_onnx_path: 準(zhǔn)備保存的onnx路徑
:param num_cls: 檢測(cè)目標(biāo)類別數(shù)
:param in_hig: 網(wǎng)絡(luò)輸入高度
:param in_wid: 網(wǎng)絡(luò)輸入寬度
:param anchor_mask: anchor寬高索引
:param opset_version: onnx算子集版本
:param simplify: 是否對(duì)模型進(jìn)行簡(jiǎn)化
:return:保存onnx到指定路徑
"""
# Build model, load weights
net = YoloBody(anchors_mask=anchor_mask,
num_classes=num_cls)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# net.load_state_dict(torch.load(pth_path, map_location=device))
net.load_state_dict(torch.load(pth_path))
# print(next(net.parameters()).device)
net = net.eval()
print(f'{pth_path} model loaded')
im = torch.zeros(1, 3, in_hig, in_wid).to('cpu')
input_layer_names = ['images']
output_layer_names = ['output']
# Export the model
print(f'Starting export with onnx {onnx.__version__}.')
torch.onnx.export(net,
im,
f=save_onnx_path,
verbose=False,
opset_version=opset_version,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=input_layer_names,
output_names=output_layer_names,
dynamic_axes=None)
# Checks
model_onnx = onnx.load(save_onnx_path) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Simplify onnx
if simplify:
import onnxsim
print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=False,
input_shapes=None)
assert check, 'assert check failed'
onnx.save(model_onnx, save_onnx_path)
print('Onnx model save as {}'.format(save_onnx_path))
if __name__ == '__main__':
pth_to_onnx(pth_path=args.pth_path,
save_onnx_path=args.save_onnx_path,
num_cls=args.num_classes,
in_hig=args.input_height,
in_wid=args.input_width,
anchor_mask=args.anchors_mask)運(yùn)行結(jié)果:
yolo.pth model loaded
Starting export with onnx 1.11.0.
Onnx model save as yolo.onnxProcess finished with exit code 0
參考鏈接:
1.yolo
2.模型部署翻車記:pytorch轉(zhuǎn)onnx踩坑實(shí)錄
到此這篇關(guān)于pytorch模型部署 pth轉(zhuǎn)onnx的文章就介紹到這了,更多相關(guān)pytorch模型部署內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實(shí)現(xiàn)的文本簡(jiǎn)單可逆加密算法示例
這篇文章主要介紹了Python實(shí)現(xiàn)的文本簡(jiǎn)單可逆加密算法,結(jié)合完整實(shí)例形式分析了Python自定義加密與解密算法具體實(shí)現(xiàn)與使用技巧,需要的朋友可以參考下2017-05-05
Python OpenCV實(shí)現(xiàn)攝像頭人臉識(shí)別功能
這篇文章主要介紹了Python OpenCV實(shí)現(xiàn)攝像頭人臉識(shí)別,使用Python 3和OpenCV進(jìn)行攝像頭人臉識(shí)別的基本步驟,本文結(jié)合實(shí)例代碼給大家介紹的非常詳細(xì),需要的朋友可以參考下2023-07-07
Python實(shí)戰(zhàn)整活之聊天機(jī)器人
這篇文章主要介紹了Python實(shí)戰(zhàn)整活之聊天機(jī)器人,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python的小伙伴們有非常好的幫助,需要的朋友可以參考下2021-04-04
使用python實(shí)現(xiàn)離散時(shí)間傅里葉變換的方法
這篇文章主要介紹了使用python實(shí)現(xiàn)離散時(shí)間傅里葉變換的方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-09-09
Python解決MySQL數(shù)據(jù)處理從SQL批量刪除報(bào)錯(cuò)
這篇文章主要為大家介紹了Python解決MySQL數(shù)據(jù)處理從SQL批量刪除報(bào)錯(cuò),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-12-12
Python實(shí)現(xiàn)輕松找出兩個(gè)列表不同之處
在日常編程中,需要比較兩個(gè)列表并找出它們之間差異是一種常見(jiàn)需求,在本文中,我們將深入探討Python中查找兩個(gè)列表差異值的方法,需要的小伙伴可以參考下2023-12-12

