Pytorch通過保存為ONNX模型轉TensorRT5的實現(xiàn)
1 Pytorch以ONNX方式保存模型
def saveONNX(model, filepath): ''' 保存ONNX模型 :param model: 神經(jīng)網(wǎng)絡模型 :param filepath: 文件保存路徑 ''' # 神經(jīng)網(wǎng)絡輸入數(shù)據(jù)類型 dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda') torch.onnx.export(model, dummy_input, filepath, verbose=True)
2 利用TensorRT5中ONNX解析器構建Engine
def ONNX_build_engine(onnx_file_path):
'''
通過加載onnx文件,構建engine
:param onnx_file_path: onnx文件路徑
:return: engine
'''
# 打印日志
G_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
builder.max_batch_size = 100
builder.max_workspace_size = 1 << 20
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
parser.parse(model.read())
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
engine = builder.build_cuda_engine(network)
print("Completed creating Engine")
# 保存計劃文件
# with open(engine_file_path, "wb") as f:
# f.write(engine.serialize())
return engine
3 構建TensorRT運行引擎進行預測
def loadONNX2TensorRT(filepath):
'''
通過onnx文件,構建TensorRT運行引擎
:param filepath: onnx文件路徑
'''
# 計算開始時間
Start = time()
engine = self.ONNX_build_engine(filepath)
# 讀取測試集
datas = DataLoaders()
test_loader = datas.testDataLoader()
img, target = next(iter(test_loader))
img = img.numpy()
target = target.numpy()
img = img.ravel()
context = engine.create_execution_context()
output = np.empty((100, 10), dtype=np.float32)
# 分配內存
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]
# pycuda操作緩沖區(qū)
stream = cuda.Stream()
# 將輸入數(shù)據(jù)放入device
cuda.memcpy_htod_async(d_input, img, stream)
# 執(zhí)行模型
context.execute_async(100, bindings, stream.handle, None)
# 將預測結果從從緩沖區(qū)取出
cuda.memcpy_dtoh_async(output, d_output, stream)
# 線程同步
stream.synchronize()
print("Test Case: " + str(target))
print("Prediction: " + str(np.argmax(output, axis=1)))
print("tensorrt time:", time() - Start)
del context
del engine
補充知識:Pytorch/Caffe可以先轉換為ONNX,再轉換為TensorRT
近來工作,試圖把Pytorch用TensorRT運行。折騰了半天,沒有完成。github中的轉換代碼,只能處理pytorch 0.2.0的功能(也明確表示不維護了)。和同事一起處理了很多例外,還是沒有通過。吾以為,實際上即使勉強過了,能不能跑也是問題。
后來有高手建議,先轉換為ONNX,再轉換為TensorRT。這個思路基本可行。
是不是這樣就萬事大吉?當然不是,還是有嚴重問題要解決的。這只是個思路。
以上這篇Pytorch通過保存為ONNX模型轉TensorRT5的實現(xiàn)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python數(shù)據(jù)分析之雙色球中藍紅球分析統(tǒng)計示例
這篇文章主要介紹了Python數(shù)據(jù)分析之雙色球中藍紅球分析統(tǒng)計,結合實例形式較為詳細的分析了Python針對雙色球藍紅球中獎數(shù)據(jù)分析的相關操作技巧,需要的朋友可以參考下2018-02-02
Pandas技巧分享之創(chuàng)建測試數(shù)據(jù)
學習pandas的過程中,為了嘗試pandas提供的各類功能強大的函數(shù),常常需要花費很多時間去創(chuàng)造測試數(shù)據(jù),本篇介紹了一些快速創(chuàng)建測試數(shù)據(jù)的方法,需要的可以參考一下2023-07-07
Python PyQt5 Pycharm 環(huán)境搭建及配置詳解(圖文教程)
這篇文章主要介紹了Python PyQt5 Pycharm 環(huán)境搭建及配置詳解,文中通過圖文介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-07-07
opencv實現(xiàn)靜態(tài)手勢識別 opencv實現(xiàn)剪刀石頭布游戲
這篇文章主要為大家詳細介紹了opencv實現(xiàn)靜態(tài)手勢識別,opencv實現(xiàn)剪刀石頭布游戲,具有一定的參考價值,感興趣的小伙伴們可以參考一下2019-01-01

