pytorch 移動(dòng)端部署之helloworld的使用
開始
安裝Androidstudio 4.1
克隆此項(xiàng)目
git clone https://github.com/pytorch/android-demo-app.git
使用androidstudio 打開 android-demo-app 中的HelloWordApp
打開之后androidstudio 會(huì)自動(dòng)創(chuàng)建依賴 只需要等待即可
這個(gè)代碼已經(jīng)是官方寫好的故而
開一下官方教程中的代碼都在什么位置
這句
repositories { jcenter() } dependencies { implementation 'org.pytorch:pytorch_android:1.4.0' implementation 'org.pytorch:pytorch_android_torchvision:1.4.0' }
位置
HelloWorldApp\app\build.gradle
里面的全部代碼
apply plugin: 'com.android.application' repositories { jcenter() } android { compileSdkVersion 28 buildToolsVersion "29.0.2" defaultConfig { applicationId "org.pytorch.helloworld" minSdkVersion 21 targetSdkVersion 28 versionCode 1 versionName "1.0" } buildTypes { release { minifyEnabled false } } } dependencies { implementation 'androidx.appcompat:appcompat:1.1.0' implementation 'org.pytorch:pytorch_android:1.4.0' implementation 'org.pytorch:pytorch_android_torchvision:1.4.0' }
這句
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg")); Module module = Module.load(assetFilePath(this, "model.pt")); Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB); Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); float[] scores = outputTensor.getDataAsFloatArray(); float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
都在這里
HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java
全部代碼
package org.pytorch.helloworld; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.util.Log; import android.widget.ImageView; import android.widget.TextView; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import androidx.appcompat.app.AppCompatActivity; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Bitmap bitmap = null; Module module = null; try { // creating bitmap from packaged into app android asset 'image.jpg', // app/src/main/assets/image.jpg bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg")); // loading serialized torchscript module from packaged into app android asset model.pt, // app/src/model/assets/model.pt module = Module.load(assetFilePath(this, "model.pt")); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } // showing image on UI ImageView imageView = findViewById(R.id.image); imageView.setImageBitmap(bitmap); // preparing input tensor final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB); // running the model final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor(); // getting tensor content as java array of floats final float[] scores = outputTensor.getDataAsFloatArray(); // searching for the index with maximum score float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx]; // showing className on UI TextView textView = findViewById(R.id.text); textView.setText(className); } /** * Copies specified asset to the file in /files app directory and returns this file absolute path. * * @return absolute file path */ public static String assetFilePath(Context context, String assetName) throws IOException { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } } }
在Build 中選擇Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug
中 這個(gè)是可以直接安裝的
安裝后是一個(gè)固定的照片 就是檢測了一個(gè)固定的照片
這是一個(gè)例子如果你只是想測試自己的模型調(diào)用能不能成功這個(gè)項(xiàng)目改改模型和模型加載即可
這個(gè)項(xiàng)目模型是一個(gè)resnet18 接著我們將其替換為resnet50
模型轉(zhuǎn)換代碼如下
import torch import torchvision.models as models from PIL import Image import numpy as np image = Image.open("test.jpg") #圖片發(fā)在了build文件夾下 image = image.resize((224, 224),Image.ANTIALIAS) image = np.asarray(image) image = image / 255 image = torch.Tensor(image).unsqueeze_(dim=0) image = image.permute((0, 3, 1, 2)).float() model = models.resnet50(pretrained=True) model = model.eval() resnet = torch.jit.trace(model, torch.rand(1,3,224,224)) # output=resnet(torch.ones(1,3,224,224)) output = resnet(image) max_index = torch.max(output, 1)[1].item() print(max_index) # ImageNet1000類的類別序 resnet.save('model.pt') if __name__ == '__main__': pass
將這個(gè)保存的模型 覆蓋掉下面路徑中的模型
(在覆蓋之前最好備份一個(gè)原來的模型,這里我們選擇修改原來模型的名字為model_1.pt)
HelloWorldApp\app\src\main\assets\model.pt
成功覆蓋后再一次執(zhí)行打包操作(在Build 中選擇Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug)
而后打開文件發(fā)現(xiàn)一個(gè)123M的apk 之前的apk是73M
安裝 并且測試
完美打開也就是說一切resnet 系列的 都可以通過這個(gè) 項(xiàng)目進(jìn)行演化出來
到此這篇關(guān)于pytorch 移動(dòng)端部署之helloworld的使用的文章就介紹到這了,更多相關(guān)pytorch 移動(dòng)端部署helloworld內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python 實(shí)現(xiàn)訓(xùn)練集、測試集隨機(jī)劃分
今天小編就為大家分享一篇Python 實(shí)現(xiàn)訓(xùn)練集、測試集隨機(jī)劃分,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01python 獲取url中的參數(shù)列表實(shí)例
今天小編就為大家分享一篇python 獲取url中的參數(shù)列表實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-1268行Python代碼實(shí)現(xiàn)帶難度升級(jí)的貪吃蛇
本文主要介紹了Python代碼實(shí)現(xiàn)帶難度升級(jí)的貪吃蛇,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-01-01快速進(jìn)修Python指南之面向?qū)ο筮M(jìn)階
這篇文章主要為大家介紹了Java開發(fā)者快速進(jìn)修Python指南之面向?qū)ο筮M(jìn)階,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-12-12Python文本統(tǒng)計(jì)功能之西游記用字統(tǒng)計(jì)操作示例
這篇文章主要介紹了Python文本統(tǒng)計(jì)功能之西游記用字統(tǒng)計(jì)操作,結(jié)合實(shí)例形式分析了Python文本讀取、遍歷、統(tǒng)計(jì)等相關(guān)操作技巧,需要的朋友可以參考下2018-05-05