JAVA训练XGBOOST
2020年开篇,但是好像没啥特别的,参加个公司年会,时间就这么过去了,就这么迎来了2020年。当然,我们国家传统是过了春节才算是新年,所以原则上2020年还没来到,但是日历上已经变了,现在2020年1月1日,觉得还是要随便写点什么,记录点什么吧。一直用python去训练模型,考虑用点不一样的技术,用java去训练一个模型吧,这样也方便和java微服务以及大数据集成在...
2020年开篇,但是好像没啥特别的,参加个公司年会,时间就这么过去了,就这么迎来了2020年。当然,我们国家传统是过了春节才算是新年,所以原则上2020年还没来到,但是日历上已经变了,现在2020年1月1日,觉得还是要随便写点什么,记录点什么吧。
一直用python去训练模型,考虑用点不一样的技术,用java去训练一个模型吧,这样也方便和java微服务以及大数据集成在一起了,让模型更好的服务于生产系统当中。其实java也提供了很多机器学习库的,像h20,Weka,Mahout等(虽然我都没用过),Tensorflow也提供了java的api。但是像我这样全技术栈的算法并不多(自吹一下),所以在具体工作配合时候难免会遇到技术栈切换的成本,当然作为工程师,掌握一门工程性比较强的语言也是还是有必要的,只是应用方向不同啦。
xgboost也提供java的很好的实现,虽然可能不如python用起来顺手,毕竟numpy,pandas这些库天生就是为处理结构化数据而打造的。其实个人之前也分析过,numpy的效率其实很高,但是怎么集成到以java为主流的应用系统是个问题。
- 创建maven项目
需要三个jar包,jar包名如图所示。
- 准备数据
import numpy as np
from sklearn.datasets import load_iris
data = load_iris()['data']
target = load_iris()['target']
x_data = np.hstack((target.reshape(-1, 1), data))
x_data = x_data [x_data [:, 0] <= 1]
train_index = [i for i in np.random.randint(0, len(x_data), int(len(x_data) * 0.8))]
test_index = list(set([i for i in range(len(x_data))]) - set(train_index))
train = x_data[train_index]
test = x_data[test_index]
with open("src/main/resources/train.txt", "w") as f:
for line in train:
f.write(",".join([str(x) for x in line]) + "\n")
with open("src/main/resources/test.txt", "w") as f:
for line in test:
f.write(",".join([str(x) for x in line]) + "\n")
- 训练并保存模型
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import java.util.HashMap;
import java.util.Map;
public class XgboostTrain {
private static DMatrix trainMat = null;
private static DMatrix testMat = null;
public static void main(String [] args) throws XGBoostError {
try {
trainMat = new DMatrix("src/main/resources/train.txt");
} catch (XGBoostError xgBoostError) {
xgBoostError.printStackTrace();
}
try {
testMat = new DMatrix("src/main/resources/test.txt");
} catch (XGBoostError xgBoostError) {
xgBoostError.printStackTrace();
}
Map<String, Object> params = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 3);
put("objective", "binary:logistic");
put("eval_metric", "logloss");
}
};
Map<String, DMatrix> watches = new HashMap<String, DMatrix>() {
{
put("train",trainMat);
put("test", testMat);
}
};
int nround = 10;
try {
Booster booster = XGBoost.train(trainMat, params, nround, watches, null, null);
booster.saveModel("src/main/resources/model.bin");
} catch (XGBoostError xgBoostError) {
xgBoostError.printStackTrace();
}
}
}
- 进行预测
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class XgboostPredict {
public static void main(String [] args) throws XGBoostError {
float[] data = new float[] {5.2f, 3.5f, 1.5f, 0.2f};
int nrow = 1;
int ncol = 4;
float missing = 0.0f;
DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
Booster booster = XGBoost.loadModel("src/main/resources/model.bin");
float[][] predicts = booster.predict(dmat);
for (float[] array: predicts){
for (float values: array) {
System.out.print(values + " ");
}
System.out.println();
}
}
}
这样就简单的实现了,xgboost模型的训练和预测,其实直接使用xgboost模型可以直接免去复杂的特征工程操作,不太需要对特征再进行归一化和one-hot编码操作,这样的话,数据准备好,启动一个SpringBoot基于Java实现的模型服务就搭建起来了。
更多推荐
所有评论(0)