深度学习模型 Pytorch Densenet 安卓移动端调用

安卓app开发调用深度学习模型

最近使用安卓调用了利用迁移学习训练的Pytorch Densenet121模型,主要是实现图像的分类,在此记录并分享一下。纯小白,第一次写博客,如有错误希望能指正,互相学习。
pytorch移动端可以参考 官方Demo

环境准备

网上都有相应的环境配置教程

  • Anacoda 我的是4.8.2 (查看版本方法:终端输入 conda --version命令)
  • python 我的是3.7.0
  • pytorch 1.4(1.3开始支持移动端,我的是1.4)
  • Java (前排提醒,安卓部分使用Java编写)
  • densenet121模型

模型转换

平时使用的后缀名为 “.pth” 的模型文件不能直接用于安卓端使用,需要转换格式为 “.pt” 。
Pytorch官方文档地址

  • 首先需要有模型文件,可以自己下载官方的模型,也可以用自己已经训练好的。
  • 在模型所在目录新建一个python文件,名字随便起(要合法),下面是desenet121模型的转换代码,resnet模型的转换可以参考这篇 博客
import torch
from torchvision import models

# 使用预训练的模型,不同类型的模型有不一样的函数,这里用的是densenet121模型
model = models.densenet121(pretrained=True)
# 如果只需要网络结构,不需要用预训练模型的参数来初始化就执行下面这行
# model = models.densenet121(pretrained=False)

#分类器,需要根据自己模型的实际情况使用,比如我这里最后是将目标图片分成A,B,C,D四类
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(1024, 512),
    torch.nn.Dropout(0.5),  # drop 50% neurons
    torch.nn.ReLU(),
    torch.nn.Linear(512, 256),
    torch.nn.Dropout(0.2),  # drop 50% neurons
    torch.nn.ReLU(),
    torch.nn.Linear(256, 4),
)

# 加载模型,参数为自己的模型的路径
model.load_state_dict(torch.load('./densenet121模型.pth')) # 加载参数
model.eval() # 模型设为评估模式

# 1张3通道400*400的图片,这里与后面的图片处理对应
input_tensor = torch.rand(1, 3, 400, 400) # 设定输入数据格式
mobile = torch.jit.trace(model, input_tensor) # 模型转化(官方文档https://pytorch.org/docs/master/generated/torch.jit.trace.html)
mobile.save('densenet121.pt') # 保存为pt文件
  • 运行上面的代码就可以得到 pt 格式的模型了

安卓配置

  • 在安卓项目目录中的 app/build.gradle 中相应位置添加以下内容:
repositories {
    jcenter()
}

dependencies {
    implementation 'org.pytorch:pytorch_android:1.4.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
  • 然后就等待加载gradle(时间可能会比较长)
  • 将转换后的 pt格式模型放到项目 app/src/main/assets 下,没有这个文件夹的话可以自己创建。
  • 编写分类标签类,用于存储类别标签。在任意项目package下创建一个MyLabel.java(名字可以自己随便取)。
package com.example.mylabel;

public class MyLabel {
    public static String[] LABEL = new String[]{
        "A", "B", "C", "D"
    };
}
  • 编写调用模型的文件,这里我使用的 java语言,在任意项目package下创建一个Test.java文件(名字自己随便取),用于调用、运行模型。
package com.example.test; 

import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log;
import com.example.myclass.MyLabel; //导入分类标签类
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;


public class Test {
    private Module module;
    private String modulePath;
    private String result;  //分类结果A,B,C,D
    private float[] resultArray;  //分类数组
    private Bitmap bitmaps;
    /**
     * 构造函数
     * @param modulePath 模型的绝对路径
     * @param bitmap 评分的图片以位图的形式传入
     */
    public Test(String modulePath, Bitmap bitmap) {
        this.modulePath = modulePath;
        this.bitmaps = resize(bitmap);
        this.module = Module.load(modulePath);
        this.resultArray = new float[]{0, 0, 0, 0};
    }

    /**
     * 重新设置图片大小并进行一定的变换与裁剪,这里返回图片的大小与之前模型转换时设置的输入数据格式相对应
     * @param bitmap 输入的图像
     * @return Bitmap 处理后的图像
     */
    private Bitmap resize(Bitmap bitmap) {
        int H = bitmap.getHeight();
        int W = bitmap.getWidth();
        float scaleH;
        float scaleW;
        float IMAGE_SIZE = 400.0f;  //图片大小设置
        //锁定高宽比缩放
        if(H>W){
            scaleH = IMAGE_SIZE * H / W;
            scaleW = IMAGE_SIZE;
        }else{
            scaleH = IMAGE_SIZE;
            scaleW = IMAGE_SIZE * W / H;
        }
        Matrix matrix = new Matrix();
        matrix.postScale(scaleW/W,scaleH/H);
        //创建一个新的bitmap
        Bitmap temp = Bitmap.createBitmap(bitmap,0,0,W,H,matrix,true);
        //在中间切割图像 使其成 IMAGE_SIZE * IMAGE_SIZE
        Bitmap resBitmap = Bitmap.createBitmap(temp,(int)((temp.getWidth()-IMAGE_SIZE)/2),
                (int)((temp.getHeight()-IMAGE_SIZE)/2),(int)IMAGE_SIZE,(int)IMAGE_SIZE);
        return resBitmap;
    }

    /**
     * 运行模型,开始分类
     */
    public void runModel(){
    	 //张量规整,根据自己的情况设置值
        float[] mean = {0.5f,0.5f,0.5f}; 
        float[] std = {0.5f,0.5f,0.5f};
        //输入张量
        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmaps,mean,std);
        //将 inputTensor 放到模型中运行,通过 module.forward() 得到一个 outputTensor。
        Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
        //获得输出结果(长度为4,分别代表A,B,C,D,与之前转换时分类器的设置有关)
        resultArray = outputTensor.getDataAsFloatArray();
        setResult();
    }

    private void setResult() {
        float max = resultArray[0];
        int index = 0;  //最大值下标
        for(int i=0;i<resultArray.length;i++){
            if(Float.compare(max,resultArray[i])<0){
                max = resultArray[i];
                index = i;
            }
        }
        //根据最大值的下标获得对应的分类结果
        this.result = MyLebel.Lebel[index];
    }

    /**
     * 获取图片评分
     * @return String
     */
    public String getResult() {
        return result;
    }
	/**
	* 获取模型路径
	* @return String
	* /
    public String getModulePath() {
        return modulePath;
    }
}
  • 编写获取assets目录下文件绝对路径的类(因为模型文件存放在该目录下),新建工具类FileUtil.java
package com.example.util;

import android.content.Context;
import android.os.Build;
import android.util.Log;
import androidx.annotation.RequiresApi;

import java.io.*;

public class FileUtil {
    /**
     * 将指定的资产复制到 /files app目录中的文件,并返回此文件的绝对路径。
     * @param context 上下文
     * @param assetName 目标文件名
     * @return String 绝对路径
     */
    @RequiresApi(api = Build.VERSION_CODES.KITKAT)
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}
  • 最后在需要运行模型的地方,(不讲究的话,比如MainActivity.java的onCreat()方法),添加下面的代码就可以了
try {
	//将图片用Bitmap类包装
     Bitmap bitmap = BitmapFactory.decodeFile("需要分类的图片路径");
     
     //获取模型路径,下面函数参数中getApplicationContext()获取的是上下文,在 某某activity中可以直接用 某某activity.this代替
     String modulePath = FileUtil.assetFilePath(getApplicationContext(), "densenet121.pt");
     
     //获取模型调用类实例
     Test test = new test(modulePath, bitmap);
     //运行模型
     test.runModel();
     //获取结果
     String label = test.getResult();
     Log.d("分类结果",label);
 } catch (IOException e) {
     e.printStackTrace();
 }

参考资料

[1] https://blog.csdn.net/y_dd6011/article/details/104751029
[2] https://pytorch.org/docs/master/generated/torch.jit.trace.html
[3] https://pytorch.org/mobile/android/

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐