2020年开篇,但是好像没啥特别的,参加个公司年会,时间就这么过去了,就这么迎来了2020年。当然,我们国家传统是过了春节才算是新年,所以原则上2020年还没来到,但是日历上已经变了,现在2020年1月1日,觉得还是要随便写点什么,记录点什么吧。

       一直用python去训练模型,考虑用点不一样的技术,用java去训练一个模型吧,这样也方便和java微服务以及大数据集成在一起了,让模型更好的服务于生产系统当中。其实java也提供了很多机器学习库的,像h20,Weka,Mahout等(虽然我都没用过),Tensorflow也提供了java的api。但是像我这样全技术栈的算法并不多(自吹一下),所以在具体工作配合时候难免会遇到技术栈切换的成本,当然作为工程师,掌握一门工程性比较强的语言也是还是有必要的,只是应用方向不同啦。

       xgboost也提供java的很好的实现,虽然可能不如python用起来顺手,毕竟numpy,pandas这些库天生就是为处理结构化数据而打造的。其实个人之前也分析过,numpy的效率其实很高,但是怎么集成到以java为主流的应用系统是个问题。

  •  创建maven项目

需要三个jar包,jar包名如图所示。

  • 准备数据
import numpy as np
from sklearn.datasets import load_iris

data = load_iris()['data']
target = load_iris()['target']

x_data = np.hstack((target.reshape(-1, 1), data))
x_data = x_data [x_data [:, 0] <= 1]

train_index = [i for i in np.random.randint(0, len(x_data), int(len(x_data) * 0.8))]
test_index = list(set([i for i in range(len(x_data))]) - set(train_index))

train = x_data[train_index]
test = x_data[test_index]

with open("src/main/resources/train.txt", "w") as f:
    for line in train:
        f.write(",".join([str(x) for x in line]) + "\n")

with open("src/main/resources/test.txt", "w") as f:
    for line in test:
        f.write(",".join([str(x) for x in line]) + "\n")
  • 训练并保存模型
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import java.util.HashMap;
import java.util.Map;


public class XgboostTrain {

    private static DMatrix trainMat = null;
    private static DMatrix testMat = null;

    public static void main(String [] args) throws XGBoostError {

        try {
            trainMat = new DMatrix("src/main/resources/train.txt");
        } catch (XGBoostError xgBoostError) {
            xgBoostError.printStackTrace();
        }
        try {
            testMat = new DMatrix("src/main/resources/test.txt");
        } catch (XGBoostError xgBoostError) {
            xgBoostError.printStackTrace();
        }

        Map<String, Object> params = new HashMap<String, Object>() {
            {
                put("eta", 1.0);
                put("max_depth", 3);
                put("objective", "binary:logistic");
                put("eval_metric", "logloss");
            }
        };

        Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
            {
                put("train",trainMat);
                put("test", testMat);
            }
        };
        int nround = 10;
        try {
            Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
            booster.saveModel("src/main/resources/model.bin");
        } catch (XGBoostError xgBoostError) {
            xgBoostError.printStackTrace();
        }
    }

}

  • 进行预测
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class XgboostPredict {

    public static void main(String [] args) throws XGBoostError {
        float[] data = new float[] {5.2f, 3.5f, 1.5f, 0.2f};
        int nrow = 1;
        int ncol = 4;
        float missing = 0.0f;
        DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
        Booster booster = XGBoost.loadModel("src/main/resources/model.bin");
        float[][] predicts = booster.predict(dmat);
        for (float[] array: predicts){
            for (float values: array) {
                System.out.print(values + " ");
            }
            System.out.println();
        }
    }

}

       这样就简单的实现了,xgboost模型的训练和预测,其实直接使用xgboost模型可以直接免去复杂的特征工程操作,不太需要对特征再进行归一化和one-hot编码操作,这样的话,数据准备好,启动一个SpringBoot基于Java实现的模型服务就搭建起来了。

Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐