深度学习模型 Pytorch Densenet 安卓移动端调用
深度学习模型 Pytorch Densenet 安卓移动端开发安卓app开发调用深度学习模型新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants创建一个自定义列表如何创建一个注脚注释也是必不可少的KaTeX数学公式新的甘特图功能,丰富你的文章UML 图表FLowcha
·
深度学习模型 Pytorch Densenet 安卓移动端调用
安卓app开发调用深度学习模型
最近使用安卓调用了利用迁移学习训练的Pytorch Densenet121模型,主要是实现图像的分类,在此记录并分享一下。纯小白,第一次写博客,如有错误希望能指正,互相学习。
pytorch移动端可以参考 官方Demo
环境准备
网上都有相应的环境配置教程
- Anacoda 我的是4.8.2 (查看版本方法:终端输入 conda --version命令)
- python 我的是3.7.0
- pytorch 1.4(1.3开始支持移动端,我的是1.4)
- Java (前排提醒,安卓部分使用Java编写)
- densenet121模型
模型转换
平时使用的后缀名为 “.pth” 的模型文件不能直接用于安卓端使用,需要转换格式为 “.pt” 。
Pytorch官方文档地址
- 首先需要有模型文件,可以自己下载官方的模型,也可以用自己已经训练好的。
- 在模型所在目录新建一个python文件,名字随便起(要合法),下面是desenet121模型的转换代码,resnet模型的转换可以参考这篇 博客
import torch
from torchvision import models
# 使用预训练的模型,不同类型的模型有不一样的函数,这里用的是densenet121模型
model = models.densenet121(pretrained=True)
# 如果只需要网络结构,不需要用预训练模型的参数来初始化就执行下面这行
# model = models.densenet121(pretrained=False)
#分类器,需要根据自己模型的实际情况使用,比如我这里最后是将目标图片分成A,B,C,D四类
model.classifier = torch.nn.Sequential(
torch.nn.Linear(1024, 512),
torch.nn.Dropout(0.5), # drop 50% neurons
torch.nn.ReLU(),
torch.nn.Linear(512, 256),
torch.nn.Dropout(0.2), # drop 50% neurons
torch.nn.ReLU(),
torch.nn.Linear(256, 4),
)
# 加载模型,参数为自己的模型的路径
model.load_state_dict(torch.load('./densenet121模型.pth')) # 加载参数
model.eval() # 模型设为评估模式
# 1张3通道400*400的图片,这里与后面的图片处理对应
input_tensor = torch.rand(1, 3, 400, 400) # 设定输入数据格式
mobile = torch.jit.trace(model, input_tensor) # 模型转化(官方文档https://pytorch.org/docs/master/generated/torch.jit.trace.html)
mobile.save('densenet121.pt') # 保存为pt文件
- 运行上面的代码就可以得到 pt 格式的模型了
安卓配置
- 在安卓项目目录中的 app/build.gradle 中相应位置添加以下内容:
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
- 然后就等待加载gradle(时间可能会比较长)
- 将转换后的 pt格式模型放到项目 app/src/main/assets 下,没有这个文件夹的话可以自己创建。
- 编写分类标签类,用于存储类别标签。在任意项目package下创建一个MyLabel.java(名字可以自己随便取)。
package com.example.mylabel;
public class MyLabel {
public static String[] LABEL = new String[]{
"A", "B", "C", "D"
};
}
- 编写调用模型的文件,这里我使用的 java语言,在任意项目package下创建一个Test.java文件(名字自己随便取),用于调用、运行模型。
package com.example.test;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log;
import com.example.myclass.MyLabel; //导入分类标签类
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
public class Test {
private Module module;
private String modulePath;
private String result; //分类结果A,B,C,D
private float[] resultArray; //分类数组
private Bitmap bitmaps;
/**
* 构造函数
* @param modulePath 模型的绝对路径
* @param bitmap 评分的图片以位图的形式传入
*/
public Test(String modulePath, Bitmap bitmap) {
this.modulePath = modulePath;
this.bitmaps = resize(bitmap);
this.module = Module.load(modulePath);
this.resultArray = new float[]{0, 0, 0, 0};
}
/**
* 重新设置图片大小并进行一定的变换与裁剪,这里返回图片的大小与之前模型转换时设置的输入数据格式相对应
* @param bitmap 输入的图像
* @return Bitmap 处理后的图像
*/
private Bitmap resize(Bitmap bitmap) {
int H = bitmap.getHeight();
int W = bitmap.getWidth();
float scaleH;
float scaleW;
float IMAGE_SIZE = 400.0f; //图片大小设置
//锁定高宽比缩放
if(H>W){
scaleH = IMAGE_SIZE * H / W;
scaleW = IMAGE_SIZE;
}else{
scaleH = IMAGE_SIZE;
scaleW = IMAGE_SIZE * W / H;
}
Matrix matrix = new Matrix();
matrix.postScale(scaleW/W,scaleH/H);
//创建一个新的bitmap
Bitmap temp = Bitmap.createBitmap(bitmap,0,0,W,H,matrix,true);
//在中间切割图像 使其成 IMAGE_SIZE * IMAGE_SIZE
Bitmap resBitmap = Bitmap.createBitmap(temp,(int)((temp.getWidth()-IMAGE_SIZE)/2),
(int)((temp.getHeight()-IMAGE_SIZE)/2),(int)IMAGE_SIZE,(int)IMAGE_SIZE);
return resBitmap;
}
/**
* 运行模型,开始分类
*/
public void runModel(){
//张量规整,根据自己的情况设置值
float[] mean = {0.5f,0.5f,0.5f};
float[] std = {0.5f,0.5f,0.5f};
//输入张量
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmaps,mean,std);
//将 inputTensor 放到模型中运行,通过 module.forward() 得到一个 outputTensor。
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
//获得输出结果(长度为4,分别代表A,B,C,D,与之前转换时分类器的设置有关)
resultArray = outputTensor.getDataAsFloatArray();
setResult();
}
private void setResult() {
float max = resultArray[0];
int index = 0; //最大值下标
for(int i=0;i<resultArray.length;i++){
if(Float.compare(max,resultArray[i])<0){
max = resultArray[i];
index = i;
}
}
//根据最大值的下标获得对应的分类结果
this.result = MyLebel.Lebel[index];
}
/**
* 获取图片评分
* @return String
*/
public String getResult() {
return result;
}
/**
* 获取模型路径
* @return String
* /
public String getModulePath() {
return modulePath;
}
}
- 编写获取assets目录下文件绝对路径的类(因为模型文件存放在该目录下),新建工具类FileUtil.java
package com.example.util;
import android.content.Context;
import android.os.Build;
import android.util.Log;
import androidx.annotation.RequiresApi;
import java.io.*;
public class FileUtil {
/**
* 将指定的资产复制到 /files app目录中的文件,并返回此文件的绝对路径。
* @param context 上下文
* @param assetName 目标文件名
* @return String 绝对路径
*/
@RequiresApi(api = Build.VERSION_CODES.KITKAT)
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
- 最后在需要运行模型的地方,(不讲究的话,比如MainActivity.java的onCreat()方法),添加下面的代码就可以了
try {
//将图片用Bitmap类包装
Bitmap bitmap = BitmapFactory.decodeFile("需要分类的图片路径");
//获取模型路径,下面函数参数中getApplicationContext()获取的是上下文,在 某某activity中可以直接用 某某activity.this代替
String modulePath = FileUtil.assetFilePath(getApplicationContext(), "densenet121.pt");
//获取模型调用类实例
Test test = new test(modulePath, bitmap);
//运行模型
test.runModel();
//获取结果
String label = test.getResult();
Log.d("分类结果",label);
} catch (IOException e) {
e.printStackTrace();
}
参考资料
[1] https://blog.csdn.net/y_dd6011/article/details/104751029
[2] https://pytorch.org/docs/master/generated/torch.jit.trace.html
[3] https://pytorch.org/mobile/android/
更多推荐
已为社区贡献1条内容
所有评论(0)