Pytorch通過保存為ONNX模型轉(zhuǎn)TensorRT5的實(shí)現(xiàn)
1 Pytorch以O(shè)NNX方式保存模型
def saveONNX(model, filepath): ''' 保存ONNX模型 :param model: 神經(jīng)網(wǎng)絡(luò)模型 :param filepath: 文件保存路徑 ''' # 神經(jīng)網(wǎng)絡(luò)輸入數(shù)據(jù)類型 dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda') torch.onnx.export(model, dummy_input, filepath, verbose=True)
2 利用TensorRT5中ONNX解析器構(gòu)建Engine
def ONNX_build_engine(onnx_file_path): ''' 通過加載onnx文件,構(gòu)建engine :param onnx_file_path: onnx文件路徑 :return: engine ''' # 打印日志 G_LOGGER = trt.Logger(trt.Logger.WARNING) with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser: builder.max_batch_size = 100 builder.max_workspace_size = 1 << 20 print('Loading ONNX file from path {}...'.format(onnx_file_path)) with open(onnx_file_path, 'rb') as model: print('Beginning ONNX file parsing') parser.parse(model.read()) print('Completed parsing of ONNX file') print('Building an engine from file {}; this may take a while...'.format(onnx_file_path)) engine = builder.build_cuda_engine(network) print("Completed creating Engine") # 保存計(jì)劃文件 # with open(engine_file_path, "wb") as f: # f.write(engine.serialize()) return engine
3 構(gòu)建TensorRT運(yùn)行引擎進(jìn)行預(yù)測
def loadONNX2TensorRT(filepath): ''' 通過onnx文件,構(gòu)建TensorRT運(yùn)行引擎 :param filepath: onnx文件路徑 ''' # 計(jì)算開始時(shí)間 Start = time() engine = self.ONNX_build_engine(filepath) # 讀取測試集 datas = DataLoaders() test_loader = datas.testDataLoader() img, target = next(iter(test_loader)) img = img.numpy() target = target.numpy() img = img.ravel() context = engine.create_execution_context() output = np.empty((100, 10), dtype=np.float32) # 分配內(nèi)存 d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize) d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize) bindings = [int(d_input), int(d_output)] # pycuda操作緩沖區(qū) stream = cuda.Stream() # 將輸入數(shù)據(jù)放入device cuda.memcpy_htod_async(d_input, img, stream) # 執(zhí)行模型 context.execute_async(100, bindings, stream.handle, None) # 將預(yù)測結(jié)果從從緩沖區(qū)取出 cuda.memcpy_dtoh_async(output, d_output, stream) # 線程同步 stream.synchronize() print("Test Case: " + str(target)) print("Prediction: " + str(np.argmax(output, axis=1))) print("tensorrt time:", time() - Start) del context del engine
補(bǔ)充知識:Pytorch/Caffe可以先轉(zhuǎn)換為ONNX,再轉(zhuǎn)換為TensorRT
近來工作,試圖把Pytorch用TensorRT運(yùn)行。折騰了半天,沒有完成。github中的轉(zhuǎn)換代碼,只能處理pytorch 0.2.0的功能(也明確表示不維護(hù)了)。和同事一起處理了很多例外,還是沒有通過。吾以為,實(shí)際上即使勉強(qiáng)過了,能不能跑也是問題。
后來有高手建議,先轉(zhuǎn)換為ONNX,再轉(zhuǎn)換為TensorRT。這個(gè)思路基本可行。
是不是這樣就萬事大吉?當(dāng)然不是,還是有嚴(yán)重問題要解決的。這只是個(gè)思路。
以上這篇Pytorch通過保存為ONNX模型轉(zhuǎn)TensorRT5的實(shí)現(xiàn)就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python數(shù)據(jù)分析之雙色球中藍(lán)紅球分析統(tǒng)計(jì)示例
這篇文章主要介紹了Python數(shù)據(jù)分析之雙色球中藍(lán)紅球分析統(tǒng)計(jì),結(jié)合實(shí)例形式較為詳細(xì)的分析了Python針對雙色球藍(lán)紅球中獎(jiǎng)數(shù)據(jù)分析的相關(guān)操作技巧,需要的朋友可以參考下2018-02-02python用socket傳輸圖片的項(xiàng)目實(shí)踐
使用python在網(wǎng)絡(luò)上傳送圖片數(shù)據(jù),需要以byte格式讀取圖片,這樣才可以通過socket傳輸,本文就來介紹了python用socket傳輸圖片的項(xiàng)目實(shí)踐,具有一定的參考價(jià)值,感興趣的可以了解一下2024-02-025種Python統(tǒng)計(jì)次數(shù)方法技巧
這篇文章主要給大家分享的是5種Python統(tǒng)計(jì)次數(shù)方法技巧,文章主要包括字典 dict 統(tǒng)計(jì)、collections.defaultdict 統(tǒng)計(jì)、List count方法、集合(set)和列表(list)統(tǒng)計(jì)、collections.Counter方法,感興趣的小伙伴一起進(jìn)入下面文章內(nèi)容吧2021-11-11Pandas技巧分享之創(chuàng)建測試數(shù)據(jù)
學(xué)習(xí)pandas的過程中,為了嘗試pandas提供的各類功能強(qiáng)大的函數(shù),常常需要花費(fèi)很多時(shí)間去創(chuàng)造測試數(shù)據(jù),本篇介紹了一些快速創(chuàng)建測試數(shù)據(jù)的方法,需要的可以參考一下2023-07-07Python PyQt5 Pycharm 環(huán)境搭建及配置詳解(圖文教程)
這篇文章主要介紹了Python PyQt5 Pycharm 環(huán)境搭建及配置詳解,文中通過圖文介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07python如何為list實(shí)現(xiàn)find方法
這篇文章主要介紹了python如何為list實(shí)現(xiàn)find方法,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-05-05opencv實(shí)現(xiàn)靜態(tài)手勢識別 opencv實(shí)現(xiàn)剪刀石頭布游戲
這篇文章主要為大家詳細(xì)介紹了opencv實(shí)現(xiàn)靜態(tài)手勢識別,opencv實(shí)現(xiàn)剪刀石頭布游戲,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-01-01