如何將pytorch模型部署到安卓上的方法示例
這篇文章演示如何將訓(xùn)練好的pytorch模型部署到安卓設(shè)備上。我也是剛開(kāi)始學(xué)安卓,代碼寫(xiě)的簡(jiǎn)單。
環(huán)境:
pytorch版本:1.10.0
模型轉(zhuǎn)化
pytorch_android支持的模型是.pt模型,我們訓(xùn)練出來(lái)的模型是.pth。所以需要轉(zhuǎn)化才可以用。先看官網(wǎng)上給的轉(zhuǎn)化方式:
import torch import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile model = torchvision.models.mobilenet_v3_small(pretrained=True) model.eval() example = torch.rand(1, 3, 224, 224) traced_script_module = torch.jit.trace(model, example) optimized_traced_model = optimize_for_mobile(traced_script_module) optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")
這個(gè)模型在安卓對(duì)應(yīng)的包:
repositories { jcenter() } dependencies { implementation 'org.pytorch:pytorch_android_lite:1.9.0' implementation 'org.pytorch:pytorch_android_torchvision:1.9.0' }
注:pytorch_android_lite版本和轉(zhuǎn)化模型用的版本要一致,不一致就會(huì)報(bào)各種錯(cuò)誤。
目前用這種方法有點(diǎn)問(wèn)題,我采用的另一種方法。
轉(zhuǎn)化代碼如下:
import torch import torch.utils.data.distributed # pytorch環(huán)境中 model_pth = 'model_31_0.96.pth' #模型的參數(shù)文件 mobile_pt ='model.pt' # 將模型保存為Android可以調(diào)用的文件 model = torch.load(model_pth) model.eval() # 模型設(shè)為評(píng)估模式 device = torch.device('cpu') model.to(device) # 1張3通道224*224的圖片 input_tensor = torch.rand(1, 3, 224, 224) # 設(shè)定輸入數(shù)據(jù)格式 mobile = torch.jit.trace(model, input_tensor) # 模型轉(zhuǎn)化 mobile.save(mobile_pt) # 保存文件
對(duì)應(yīng)的包:
//pytorch implementation 'org.pytorch:pytorch_android:1.10.0' implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
定義模型文件和轉(zhuǎn)化后的文件路徑。
load模型。這里要注意,如果保存模型
torch.save(model,'models.pth')
加載模型則是
model=torch.load('models.pth')
如果保存模型是
torch.save(model.state_dict(),"models.pth")
加載模型則是
model.load_state_dict(torch.load('models.pth'))
定義輸入數(shù)據(jù)格式。
模型轉(zhuǎn)化,然后再保存模型。
安卓部署
新建項(xiàng)目
新建安卓項(xiàng)目,選擇Empy Activity,然后選擇Next
然后,填寫(xiě)項(xiàng)目信息,選擇安卓版本,我用的4.4,點(diǎn)擊完成
導(dǎo)入包
導(dǎo)入pytorch_android的包
//pytorch implementation 'org.pytorch:pytorch_android:1.10.0' implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
如果有參數(shù)報(bào)錯(cuò)請(qǐng)參照我的完整的配置,代碼如下:
plugins { id 'com.android.application' } android { compileSdk 32 defaultConfig { applicationId "com.example.myapplication" minSdk 21 targetSdk 32 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" } buildTypes { release { minifyEnabled false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' } } compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 } } dependencies { implementation 'androidx.appcompat:appcompat:1.3.0' implementation 'com.google.android.material:material:1.4.0' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' testImplementation 'junit:junit:4.13.2' androidTestImplementation 'androidx.test.ext:junit:1.1.3' androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' //pytorch implementation 'org.pytorch:pytorch_android:1.10.0' implementation 'org.pytorch:pytorch_android_torchvision:1.10.0' }
頁(yè)面文件
頁(yè)面的配置如下:
<?xml version="1.0" encoding="utf-8"?> <FrameLayout xmlns:android="http://schemas.android.com/apk/res/android" xmlns:tools="http://schemas.android.com/tools" android:layout_width="match_parent" android:layout_height="match_parent" tools:context=".MainActivity"> <ImageView android:id="@+id/image" android:layout_width="match_parent" android:layout_height="match_parent" android:scaleType="fitCenter" /> <TextView android:id="@+id/text" android:layout_width="match_parent" android:layout_height="wrap_content" android:layout_gravity="top" android:textSize="24sp" android:background="#80000000" android:textColor="@android:color/holo_red_light" /> </FrameLayout>
這個(gè)頁(yè)面只有兩個(gè)空間,一個(gè)展示圖片,一個(gè)顯示文字。
模型推理
新增assets文件夾,然后將轉(zhuǎn)化的模型和待測(cè)試的圖片放進(jìn)去。
新增ImageNetClasses類(lèi),這個(gè)類(lèi)存放類(lèi)別名字。
代碼如下:
package com.example.myapplication; public class ImageNetClasses { public static String[] IMAGENET_CLASSES = new String[]{ "Black-grass", "Charlock", "Cleavers", "Common Chickweed", "Common wheat", "Fat Hen", "Loose Silky-bent", "Maize", "Scentless Mayweed", "Shepherds Purse", "Small-flowered Cranesbill", "Sugar beet", }; }
在MainActivity類(lèi)中,增加模型推理的邏輯。完成代碼如下:
package com.example.myapplication; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.util.Log; import android.widget.ImageView; import android.widget.TextView; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import org.pytorch.MemoryFormat; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import androidx.appcompat.app.AppCompatActivity; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Bitmap bitmap = null; Module module = null; try { // creating bitmap from packaged into app android asset 'image.jpg', // app/src/main/assets/image.jpg bitmap = BitmapFactory.decodeStream(getAssets().open("1.png")); // loading serialized torchscript module from packaged into app android asset model.pt, // app/src/model/assets/model.pt module = Module.load(assetFilePath(this, "models.pt")); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } // showing image on UI ImageView imageView = findViewById(R.id.image); imageView.setImageBitmap(bitmap); // preparing input tensor final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST); // running the model final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); // getting tensor content as java array of floats final float[] scores = outputTensor.getDataAsFloatArray(); // searching for the index with maximum score float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } System.out.println(maxScoreIdx); String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx]; // showing className on UI TextView textView = findViewById(R.id.text); textView.setText(className); } /** * Copies specified asset to the file in /files app directory and returns this file absolute path. * * @return absolute file path */ public static String assetFilePath(Context context, String assetName) throws IOException { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } } }
然后運(yùn)行。
到此這篇關(guān)于如何將pytorch模型部署到安卓上的方法示例的文章就介紹到這了,更多相關(guān)pytorch模型部署到安卓?jī)?nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
scrapy框架攜帶cookie訪問(wèn)淘寶購(gòu)物車(chē)功能的實(shí)現(xiàn)代碼
這篇文章主要介紹了scrapy框架攜帶cookie訪問(wèn)淘寶購(gòu)物車(chē),本文通過(guò)實(shí)例代碼圖文詳解給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-07-07開(kāi)源軟件包和環(huán)境管理系統(tǒng)Anaconda的安裝使用
Anaconda是一個(gè)用于科學(xué)計(jì)算的Python發(fā)行版,支持 Linux, Mac, Windows系統(tǒng),提供了包管理與環(huán)境管理的功能,可以很方便地解決多版本python并存、切換以及各種第三方包安裝問(wèn)題。2017-09-09Python使用pickle進(jìn)行序列化和反序列化的示例代碼
這篇文章主要介紹了Python使用pickle進(jìn)行序列化和反序列化,幫助大家更好的理解和使用python的pickle庫(kù),感興趣的朋友可以了解下2020-09-09python?pip安裝的包目錄(site-packages目錄的位置)
這篇文章主要介紹了python?pip安裝的包放在哪里(site-packages目錄的位置),本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2023-03-03k-means 聚類(lèi)算法與Python實(shí)現(xiàn)代碼
這篇文章主要介紹了k-means 聚類(lèi)算法與Python實(shí)現(xiàn)代碼,本文通過(guò)示例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-06-06