pytorch模型轉(zhuǎn)換為onnx可視化(使用netron)
pytorch模型轉(zhuǎn)換為onnx,并使用netron可視化
netron 是一個(gè)非常好用的網(wǎng)絡(luò)結(jié)構(gòu)可視化工具。
但是netron對pytorch模型的支持還不成熟。自己試的效果是生成的模型圖沒有連線。
目前支持的框架 根據(jù)netron的github
目前netron支持:
ONNX (.onnx, .pb, .pbtxt) Keras (.h5, .keras) Core ML (.mlmodel) Caffe (.caffemodel, .prototxt) Caffe2 (predict_net.pb, predict_net.pbtxt) Darknet (.cfg) MXNet (.model, -symbol.json) ncnn (.param) TensorFlow Lite (.tflite) PaddlePaddle (.zip, model) TensorFlow.js CNTK (.model, .cntk)
并且實(shí)驗(yàn)性支持:
TorchScript (.pt, .pth) PyTorch (.pt, .pth) Torch (.t7) Arm NN (.armnn) BigDL (.bigdl, .model) Chainer (.npz, .h5) Deeplearning4j (.zip) MediaPipe (.pbtxt) ML.NET (.zip), MNN (.mnn) OpenVINO (.xml) scikit-learn (.pkl) TensorFlow (.pb, .meta, .pbtxt, .ckpt, .index)
Netron supports ONNX, TensorFlow Lite, Caffe, Keras, Darknet, PaddlePaddle, ncnn, MNN, Core ML, RKNN, MXNet, MindSpore Lite, TNN, Barracuda, Tengine, CNTK, TensorFlow.js, Caffe2 and UFF.
Netron has experimental support for PyTorch, TensorFlow, TorchScript, OpenVINO, Torch, Vitis AI, kmodel, Arm NN, BigDL, Chainer, Deeplearning4j, MediaPipe, ML.NET and scikit-learn.
這里就有一個(gè)把 .pth 模型轉(zhuǎn)化為 .onnx 模型。
Pytorch模型轉(zhuǎn)onnx
model = resnet18(pretrained=True) # print(model) # old_net_path = "resnet18.pth" new_net_path = "./resnet18.onnx" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 導(dǎo)入模型 net = model.to(device) # net.load_state_dict(torch.load(old_net_path, map_location=device)) net.eval() input = torch.randn(1, 3, 224, 224).to(device) # BCHW 其中Batch必須為1,因?yàn)闇y試時(shí)一般為1,尺寸HW必須和訓(xùn)練時(shí)的尺寸一致 torch.onnx.export(net, input, new_net_path, verbose=False)
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None)
參數(shù):
model(torch.nn.Module)-要被導(dǎo)出的模型 args(參數(shù)的集合)-模型的輸入,例如,這種model(*args)方式是對模型的有效調(diào)用。任何非Variable參數(shù)都將硬編碼到導(dǎo)出的模型中;任何Variable參數(shù)都將成為導(dǎo)出的模型的輸入,并按照他們在args中出現(xiàn)的順序輸入。如果args是一個(gè)Variable,這等價(jià)于用包含這個(gè)Variable的1-ary元組調(diào)用它。(注意:現(xiàn)在不支持向模型傳遞關(guān)鍵字參數(shù)。) f-一個(gè)類文件的對象(必須實(shí)現(xiàn)文件描述符的返回)或一個(gè)包含文件名字符串。一個(gè)二進(jìn)制Protobuf將會寫入這個(gè)文件中。 export_params(bool,default True)-如果指定,所有參數(shù)都會被導(dǎo)出。如果你只想導(dǎo)出一個(gè)未訓(xùn)練的模型,就將此參數(shù)設(shè)置為False。在這種情況下,導(dǎo)出的模型將首先把所有parameters作為參arguments,順序由model.state_dict().values()指定。 verbose(bool,default False)-如果指定,將會輸出被導(dǎo)出的軌跡的調(diào)試描述。 training(bool,default False)-導(dǎo)出訓(xùn)練模型下的模型。目前,ONNX只面向推斷模型的導(dǎo)出,所以一般不需要將該項(xiàng)設(shè)置為True。 input_names(list of strings, default empty list)-按順序分配名稱到圖中的輸入節(jié)點(diǎn)。 output_names(list of strings, default empty list)-按順序分配名稱到圖中的輸出節(jié)點(diǎn)。
文件中保存模型結(jié)構(gòu)和權(quán)重參數(shù)
import torch torch_model = torch.load("save.pt") # pytorch模型加載 batch_size = 1 #批處理大小 input_shape = (3,244,244) #輸入數(shù)據(jù) # set the model to inference mode torch_model.eval() x = torch.randn(batch_size,*input_shape) # 生成張量 export_onnx_file = "test.onnx" # 目的ONNX文件名 torch.onnx.export(torch_model, x, export_onnx_file, opset_version=10, do_constant_folding=True, # 是否執(zhí)行常量折疊優(yōu)化 input_names=["input"], # 輸入名 output_names=["output"], # 輸出名 dynamic_axes={"input":{0:"batch_size"}, # 批處理變量 "output":{0:"batch_size"}})
dynamic_axes字段用于批處理.若不想支持批處理或固定批處理大小,移除dynamic_axes字段即可.
文件中只保留模型權(quán)重
import torch torch_model = selfmodel() # 由研究員提供python.py文件 batch_size = 1 # 批處理大小 input_shape = (3, 244, 244) # 輸入數(shù)據(jù) # set the model to inference mode torch_model.eval() x = torch.randn(batch_size,*input_shape) # 生成張量 export_onnx_file = "test.onnx" # 目的ONNX文件名 torch.onnx.export(torch_model, x, export_onnx_file, opset_version=10, do_constant_folding=True, # 是否執(zhí)行常量折疊優(yōu)化 input_names=["input"], # 輸入名 output_names=["output"], # 輸出名 dynamic_axes={"input":{0:"batch_size"}, # 批處理變量 "output":{0:"batch_size"}})
到此這篇關(guān)于pytorch模型轉(zhuǎn)換為onnx可視化(使用netron)的文章就介紹到這了,更多相關(guān)pytorch模型轉(zhuǎn)onnx可視化內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python從csv文件中讀取數(shù)據(jù)及提取數(shù)據(jù)的方法
這篇文章主要介紹了Python從csv文件中讀取數(shù)據(jù)并提取數(shù)據(jù)的方法,文中通過多種方法給大家講解獲取指定列的數(shù)據(jù),并存入一個(gè)數(shù)組中,每種方法通過實(shí)例代碼給大家介紹的非常詳細(xì),需要的朋友參考下吧2021-11-11Windows下Python使用Pandas模塊操作Excel文件的教程
Pandas是一個(gè)強(qiáng)大的Python數(shù)據(jù)分析模塊,這里我們先使用ANACONDA來幫助獲取Pandas所以來的一些環(huán)境,然后來初步學(xué)習(xí)Windows下Python使用Pandas模塊操作Excel文件的教程2016-05-05Python+flask編寫一個(gè)簡單實(shí)用的自動(dòng)排班系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了如何基于Python+flask編寫一個(gè)簡單實(shí)用的自動(dòng)排班系統(tǒng),文中的示例代碼講解詳細(xì),有需要的小伙伴可以了解下2025-03-03sublime python3 輸入換行不結(jié)束的方法
下面小編就為大家分享一篇sublime python3 輸入換行不結(jié)束的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04