使用LibTorch進(jìn)行C++調(diào)用pytorch模型方式
前天由于某些原因需要利用C++調(diào)用PyTorch,于是接觸到了LibTorch,配了兩天最終有了一定的效果,于是記錄一下。
環(huán)境
- PyTorch1.6.0
- cuda10.2
- opencv4.4.0
- VS2017
具體過(guò)程
下載LibTorch
去PyTorch官網(wǎng)下載LibTorch包,選擇對(duì)應(yīng)的版本,這里我選擇Stable(1.6.0),Windows,LibTorch,C++/JAVA,10.2,然后我選擇release版本下載,如下圖
下載完后先不用管它,之后再用
用pytorch生成模型文件
我先創(chuàng)建了一個(gè)python文件,加載resnet50預(yù)訓(xùn)練模型,用來(lái)生成模型文件,代碼如下
import torch import torchvision.models as models from PIL import Image import numpy as np from torchvision import transforms model_resnet = models.resnet50(pretrained=True).cuda() # model_resnet.load_state_dict(torch.load("resnet_Epoch_4_Top1_99.75845336914062.pkl")) model_resnet.eval() # 自己選擇任意一張圖片,并將它的路徑寫(xiě)在open方法里,用來(lái)讀取圖像,我這里路徑就是‘111.jpg'了 image = Image.open("111.jpg").convert('RGB') tf = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), # transforms.Normalize(mean=[0.5]*3, std=[0.5]*3) ]) img = tf(image) img = img.unsqueeze(dim=0) print(img.shape) input = torch.rand(1, 3, 224, 224).cuda() traced_script_module_resnet = torch.jit.trace(model_resnet, input) output = traced_script_module_resnet(img.cuda()) print(output.shape) pred = torch.argmax(output, dim=1) print(pred) traced_script_module_resnet.save("model_resnet_jit_cuda.pt")
最后可以生成一個(gè)model_resnet_jit_cuda.pt文件,產(chǎn)生的輸出如下所示
第一行是我們讀取圖像的shape,我們讀取圖片之后經(jīng)過(guò)各種resize,增加維度,把圖片數(shù)據(jù)的shape修改成模型接受的格式,可以看到預(yù)測(cè)的結(jié)果是921,之后我們將用到生成的model_resnet_jit_cuda.pt文件。
VS創(chuàng)建工程并進(jìn)行環(huán)境配置
我在這個(gè)python文件路徑下創(chuàng)建了這個(gè)vs工程Project1
創(chuàng)建完成之后我們打開(kāi)Project1文件夾,里面內(nèi)容如下
現(xiàn)在創(chuàng)建VS工程先告一段落,開(kāi)始進(jìn)行工程環(huán)境配置。把之前下載的LibTorch,解壓到當(dāng)前目錄,解壓后會(huì)出現(xiàn)一個(gè)libtorch的文件夾,文件夾目錄里的內(nèi)容為
這里將我框選的文件夾路徑配置到工程屬性當(dāng)中,打開(kāi)剛才新建的VS工程,選擇項(xiàng)目為relaese的×64版本
然后點(diǎn)擊項(xiàng)目->Project1屬性,彈出屬性頁(yè)
在屬性頁(yè)同樣注意是release的×64平臺(tái),點(diǎn)擊VC++目錄,在包含目錄下加載我之前框出來(lái)的include文件夾路徑,在庫(kù)目錄下加載框出來(lái)的lib文件夾路徑,同時(shí),我們也要用到opencv,所以也需要在包含目錄下加載opencv的include文件夾與opencv2文件夾,在庫(kù)目錄下加載opencv\build\x64\vc14\lib,如下圖
然后在屬性頁(yè)的鏈接器->輸入,添加附加依賴(lài)項(xiàng),首先先把opencv的依賴(lài)項(xiàng)添加了
opencv_world440.lib,(如果一直用的Debug模式,就添加opencv_world440d.lib),然后將libtorch/lib里所有后綴為.lib的文件全添加進(jìn)來(lái),打開(kāi)這個(gè)文件夾
全都寫(xiě)進(jìn)去,再點(diǎn)擊確定,如下圖所示
然后點(diǎn)擊鏈接器->命令行,加上/INCLUDE:?warp_size@cuda@at@@YAHXZ 這一句,加上這一句是因?yàn)槲覀円胏uda版本的,如果是cpu版本可以不加。
最后點(diǎn)擊C/C++ ->常規(guī)的SDL檢查,設(shè)置為否
點(diǎn)擊C/C++ ->語(yǔ)言的符合模式,設(shè)置為否
到此我們的配置就全部結(jié)束了!最后!復(fù)制libtorch/lib文件夾下所有文件,粘貼到工程文件夾Project1/×64/release文件夾里(點(diǎn)擊此處的Project1文件夾可以發(fā)現(xiàn)里面也有一個(gè)×64/release,之前我也糾結(jié)是放在哪,然后我都試了一下,發(fā)現(xiàn)這個(gè)里面是可以不放的)
運(yùn)行VS2017工程文件
然后我運(yùn)行VS工程下一個(gè)空的main文件,沒(méi)有報(bào)錯(cuò),配置大致是沒(méi)問(wèn)題的,最后添加完整代碼,如下
#include <torch/script.h> // One-stop header. #include <opencv2/opencv.hpp> #include <iostream> #include <memory> //https://pytorch.org/tutorials/advanced/cpp_export.html std::string image_path = "../../111.jpg"; int main(int argc, const char* argv[]) { // Deserialize the ScriptModule from a file using torch::jit::load(). //std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../../model_resnet_jit.pt"); using torch::jit::script::Module; Module module = torch::jit::load("../../model_resnet_jit_cuda.pt"); module.to(at::kCUDA); //assert(module != nullptr); //std::cout << "ok\n"; //輸入圖像 auto image = cv::imread(image_path, cv::ImreadModes::IMREAD_COLOR); cv::cvtColor(image, image, cv::COLOR_BGR2RGB); cv::Mat image_transfomed; cv::resize(image, image_transfomed, cv::Size(224, 224)); // 轉(zhuǎn)換為T(mén)ensor torch::Tensor tensor_image = torch::from_blob(image_transfomed.data, { image_transfomed.rows, image_transfomed.cols,3 }, torch::kByte); tensor_image = tensor_image.permute({ 2,0,1 }); tensor_image = tensor_image.toType(torch::kFloat); tensor_image = tensor_image.div(255); tensor_image = tensor_image.unsqueeze(0); tensor_image = tensor_image.to(at::kCUDA); // 網(wǎng)絡(luò)前向計(jì)算 at::Tensor output = module.forward({ tensor_image }).toTensor(); //std::cout << "output:" << output << std::endl; auto prediction = output.argmax(1); std::cout << "prediction:" << prediction << std::endl; int maxk = 3; auto top3 = std::get<1>(output.topk(maxk, 1, true, true)); std::cout << "top3: " << top3 << '\n'; std::vector<int> res; for (auto i = 0; i < maxk; i++) { res.push_back(top3[0][i].item().toInt()); } for (auto i : res) { std::cout << i << " "; } std::cout << "\n"; system("pause"); }
得到最終輸出為921,可以看到和之前的python文件下輸出一致,這里還輸出了它的top前三,分別是921,787,490。
注意到,我的這兩個(gè)輸出相同的前提條件是:
1、確定加載的是由對(duì)應(yīng)python文件生成的模型!
2、輸入的圖片是同一張!并且在python下和C++下進(jìn)行了同樣的轉(zhuǎn)換,這里我在python下,將它進(jìn)行了RGB模型的轉(zhuǎn)換,resize(224, 224),并且將它的每一個(gè)元素值除以255.0,轉(zhuǎn)換到0~1之間(ToTensor()方法),最后維度轉(zhuǎn)換為1, 3, 224, 224,在C++中同樣需要將BGR模型轉(zhuǎn)化為RGB模型,進(jìn)行圖像縮放至224,224,并且將像素值除以255,將類(lèi)型轉(zhuǎn)化為float類(lèi)型,最后維度同樣轉(zhuǎn)換為1,3,224,224,再進(jìn)行網(wǎng)絡(luò)前向計(jì)算。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python tkinter Entry控件的焦點(diǎn)移動(dòng)操作
這篇文章主要介紹了python tkinter Entry控件的焦點(diǎn)移動(dòng)操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-05-05Python網(wǎng)絡(luò)編程之xmlrpc模塊
這篇文章介紹了Python網(wǎng)絡(luò)編程之xmlrpc模塊,文中通過(guò)示例代碼介紹的非常詳細(xì)。對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-05-05python操作mysql實(shí)現(xiàn)一個(gè)超市管理系統(tǒng)
超市管理系統(tǒng)有管理員和普通用戶兩條分支,只需掌握Python基礎(chǔ)語(yǔ)法,就可以完成這個(gè)項(xiàng)目,下面這篇文章主要給大家介紹了關(guān)于python操作mysql實(shí)現(xiàn)一個(gè)超市管理系統(tǒng)的相關(guān)資料,需要的朋友可以參考下2022-12-12unittest+coverage單元測(cè)試代碼覆蓋操作實(shí)例詳解
這篇文章主要為大家詳細(xì)介紹了unittest+coverage單元測(cè)試代碼覆蓋操作的實(shí)例,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-04-04Django中的Model操作表的實(shí)現(xiàn)
這篇文章主要介紹了Django中的Model操作表的實(shí)現(xiàn),小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-07-07淺談Python在pycharm中的調(diào)試(debug)
今天小編就為大家分享一篇淺談Python在pycharm中的調(diào)試(debug),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-11-11