Alink
本文介绍 Alink 的环境安装(通过 pip 安装 pyalink-flink)以及常用机器学习算法的 Java Pipeline 代码示例。
Alink
环境安装
安装
pip install pyalink-flink-1.10 --user -i https://mirrors.aliyun.com/pypi/simple
检查安装
pip search pyalink
如下则是安装成功
pyalink (1.2.0) - Alink Python API
pyalink-flink-1.9 (1.2.0) - Alink Python API
pyalink-flink-1.10 (1.2.0) - Alink Python API
INSTALLED: 1.2.0 (latest)
<properties>
<java.version>1.8</java.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<encoding>UTF-8</encoding>
<scala.version>2.11.8</scala.version>
<scala.compat.version>2.11</scala.compat.version>
<hadoop.version>2.7.2</hadoop.version>
<flink.version>1.10.0</flink.version>
<kafka.version>1.1.1</kafka.version>
</properties>
代码测试
协同过滤
ALS
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.SplitBatchOp;
import com.alibaba.alink.operator.batch.recommendation.AlsTrainBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.PipelineStageBase;
import com.alibaba.alink.pipeline.clustering.KMeansModel;
import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
import com.alibaba.alink.pipeline.recommendation.ALS;
import com.alibaba.alink.pipeline.recommendation.ALSModel;
/**
* Created by edc on 2020/8/6
*/
public class PiplineExample {
    public static void main(String[] args) throws Exception {
        // Data source: MovieLens ratings CSV hosted by Alink.
        String url = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/movielens_ratings.csv";
        String schema = "userid bigint, movieid bigint, rating double, timestamp string";
        BatchOperator ratings = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schema);
        // 80/20 split: the operator's main output is the training set,
        // side output 0 is the held-out test set.
        SplitBatchOp splitter = new SplitBatchOp().setFraction(0.8);
        splitter.linkFrom(ratings);
        BatchOperator train = splitter;
        BatchOperator test = splitter.getSideOutput(0);
        // ALS recommender wrapped in a one-stage pipeline.
        ALS als = new ALS()
                .setUserCol("userid")
                .setItemCol("movieid")
                .setRateCol("rating")
                .setNumIter(10)
                .setRank(10)
                .setLambda(0.1)
                .setPredictionCol("pred_rating");
        PipelineModel model = new Pipeline().add(als).fit(train);
        model.save("/home/edc/alink-als.csv");
        model.transform(test).print();
    }
}
结果展示
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uM585LfQ-1598841831485)(/home/edc/Alink.assets/image-20200806101650749.png)]
分类
决策树
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.SplitBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.common.tree.seriestree.DecisionTree;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.DecisionTreeClassificationModel;
import com.alibaba.alink.pipeline.classification.DecisionTreeClassifier;
import com.alibaba.alink.pipeline.recommendation.ALS;
import com.alibaba.alink.pipeline.regression.DecisionTreeRegressor;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_DecisionTreeRegressor {
    public static void main(String[] args) throws Exception {
        // Data source: local CSV with four feature columns and a label column.
        String filepath = "/home/edc/exampleData.csv";
        String schema = "f0 double, f1 string, f2 bigint, f3 bigint ,label bigint";
        BatchOperator trainData = new CsvSourceBatchOp()
                .setFilePath(filepath).setSchemaStr(schema);
        trainData.print();
        // Build the pipeline.
        // FIX: the original chained four separate setFeatureCols("fX") calls;
        // each call REPLACES the previously set value, so only "f3" was ever
        // used as a feature. All four columns must be passed in a single call.
        Pipeline pipeline = new Pipeline();
        Pipeline pipeline1 = pipeline
                .add(new DecisionTreeClassifier()
                        .setFeatureCols("f0", "f1", "f2", "f3")
                        .setLabelCol("label")
                        .setPredictionCol("pred"));
        PipelineModel pipelineModel = pipeline1.fit(trainData);
        pipelineModel.save("/home/edc/alink-decisionTree.csv");
        pipelineModel.transform(trainData).print();
    }
}
结果展示
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8G9HLGyx-1598841831486)(/home/edc/Alink.assets/image-20200806134734142.png)]
随机森林
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LinearSvm;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_RandomForestClassifier {
    public static void main(String[] args) throws Exception {
        // Data source: comma-delimited CSV with four feature columns and a label.
        String url = "/home/edc/exampleData.csv";
        String schema = "f0 double, f1 string, f2 bigint, f3 bigint ,label bigint";
        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url).setSchemaStr(schema);
        // Build the pipeline.
        // FIX: the original chained four separate setFeatureCols("fX") calls;
        // each call overwrites the previous one, so only "f3" was actually
        // used. Pass all feature columns in one call instead.
        Pipeline pipeline = new Pipeline();
        Pipeline pipeline1 = pipeline
                .add(new RandomForestClassifier()
                        .setFeatureCols("f0", "f1", "f2", "f3")
                        .setLabelCol("label")
                        .setPredictionCol("pred")
                        .setPredictionDetailCol("pred_detail"));
        PipelineModel pipelineModel = pipeline1.fit(trainData);
        pipelineModel.save("/home/edc/alink-RandomForestClassifier.csv");
        pipelineModel.transform(trainData).print();
        // PipelineModel pipelineModel1 = PipelineModel.load("/home/edc/alink-RandomForestClassifier.csv");
        // pipelineModel1.transform(trainData).print();
    }
}
结果
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vwxYH2eC-1598841831487)(/home/edc/Alink.assets/image-20200806145242868.png)]
多层感知机
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_MultilayerPerceptronClassifier {
    public static void main(String[] args) throws Exception {
        // Data source: the iris data set (4 numeric features, string label).
        String url = "/home/edc/iris.csv";
        String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
        BatchOperator trainData = new CsvSourceBatchOp()
                // .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);
        // One-stage pipeline: a 4-5-3 feed-forward network, 20 iterations.
        MultilayerPerceptronClassifier mlp = new MultilayerPerceptronClassifier()
                .setFeatureCols(new String[]{"sepal_length", "sepal_width", "petal_length", "petal_width"})
                .setLabelCol("category")
                .setLayers(new int[]{4, 5, 3})
                .setMaxIter(20)
                .setPredictionCol("pred_label")
                .setPredictionDetailCol("pred_detail");
        PipelineModel pipelineModel = new Pipeline().add(mlp).fit(trainData);
        // pipelineModel.save("/home/edc/alink-MultilayerPerceptronClassifier.csv");
        // pipelineModel.transform(trainData).firstN(4).print();
        // NOTE(review): loads a previously saved model from disk — assumes the
        // save above was executed in an earlier run; verify the file exists.
        PipelineModel loaded = PipelineModel.load("/home/edc/alink-MultilayerPerceptronClassifier.csv");
        loaded.transform(trainData).firstN(4).print();
    }
}
结果
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CLyUtocL-1598841831490)(/home/edc/Alink.assets/image-20200806153549431.png)]
gbdt二分类
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_GbdtClassifier {
    public static void main(String[] args) throws Exception {
        // Data source: comma-delimited CSV with four feature columns and a label.
        String url = "/home/edc/exampleData.csv";
        String schema = "f0 double, f1 string, f2 bigint, f3 bigint ,label bigint";
        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);
        // GBDT binary classifier: 3 trees, learning rate 1.0, leaves may
        // hold a single sample.
        GbdtClassifier gbdt = new GbdtClassifier()
                .setLearningRate(1.0)
                .setNumTrees(3)
                .setMinSamplesPerLeaf(1)
                .setPredictionDetailCol("pred_detail")
                .setPredictionCol("pred")
                .setLabelCol("label")
                .setFeatureCols("f0", "f1", "f2", "f3");
        PipelineModel pipelineModel = new Pipeline().add(gbdt).fit(trainData);
        // pipelineModel.save("/home/edc/alink-GbdtClassifier.csv");
        // pipelineModel.transform(trainData).print();
        // NOTE(review): assumes the saved model already exists on disk from a
        // previous run — confirm before running.
        PipelineModel restored = PipelineModel.load("/home/edc/alink-GbdtClassifier.csv");
        restored.transform(trainData).print();
    }
}
结果
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ged2DWkS-1598841831491)(/home/edc/Alink.assets/image-20200806154349592.png)]
softmax 算法
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.Softmax;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_Softmax {
    public static void main(String[] args) throws Exception {
        // Data source: two integer features plus an integer label.
        String url = "/home/edc/alink-Softmax-data.csv";
        String schema = "f0 int, f1 int, label int";
        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);
        // Multiclass softmax classifier in a single-stage pipeline.
        Softmax softmax = new Softmax()
                .setFeatureCols("f0", "f1")
                .setLabelCol("label")
                .setPredictionCol("pred");
        PipelineModel pipelineModel = new Pipeline().add(softmax).fit(trainData);
        // pipelineModel.save("/home/edc/alink-Softmax.csv");
        // pipelineModel.transform(trainData).print();
        // NOTE(review): assumes the model file was saved in a previous run.
        PipelineModel restored = PipelineModel.load("/home/edc/alink-Softmax.csv");
        restored.transform(trainData).print();
    }
}
结果
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8MvD3Ker-1598841831492)(/home/edc/Alink.assets/image-20200806155103625.png)]
逻辑回归
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.Softmax;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_LogisticRegression {
    public static void main(String[] args) throws Exception {
        // Data source: two integer features plus an integer label
        // (reuses the softmax example data file).
        String url = "/home/edc/alink-Softmax-data.csv";
        String schema = "f0 int, f1 int, label int";
        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);
        // Logistic-regression classifier in a single-stage pipeline.
        LogisticRegression lr = new LogisticRegression()
                .setFeatureCols("f0", "f1")
                .setLabelCol("label")
                .setPredictionCol("pred");
        PipelineModel pipelineModel = new Pipeline().add(lr).fit(trainData);
        // pipelineModel.save("/home/edc/alink-LogisticRegression.csv");
        // pipelineModel.transform(trainData).print();
        // NOTE(review): assumes the model file was saved in a previous run.
        PipelineModel restored = PipelineModel.load("/home/edc/alink-LogisticRegression.csv");
        restored.transform(trainData).print();
    }
}
结果
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qoQs21qy-1598841831492)(/home/edc/Alink.assets/image-20200806155809839.png)]
OneVsRest
package com.alink.ml;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.OneVsRest;
import com.alibaba.alink.pipeline.classification.Softmax;
/**
* Created by edc on 2020/8/6
*/
public class Pipline_OneVsRest {
    public static void main(String[] args) throws Exception {
        // Data source: the iris data set (4 numeric features, 3 classes).
        String url = "/home/edc/iris.csv";
        String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);
        // Base binary classifier, trained once per class (one-vs-rest).
        LogisticRegression base = new LogisticRegression()
                .setFeatureCols("sepal_length", "sepal_width", "petal_length", "petal_width")
                .setLabelCol("category")
                .setMaxIter(100);
        OneVsRest ovr = new OneVsRest()
                .setClassifier(base)
                .setNumClass(3)
                .setPredictionCol("pred_result")
                .setPredictionDetailCol("pred_detail");
        PipelineModel pipelineModel = new Pipeline().add(ovr).fit(trainData);
        pipelineModel.save("/home/edc/alink-OneVsRest.csv");
        pipelineModel.transform(trainData).print();
        //
        // PipelineModel pipelineModel1 = PipelineModel.load("/home/edc/alink-OneVsRest.csv");
        // pipelineModel1.transform(trainData).print();
    }
}
结果
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Wtm8jKaz-1598841831493)(/home/edc/Alink.assets/image-20200806160754376.png)]
统计分析
相关系数(batch)
卡方检验(batch)
全表统计(batch)
数据处理
缺失值填充
// Missing-value imputation chained with a softmax classifier.
Pipeline pipeline = new Pipeline();
Pipeline pipeline1 = pipeline
// Stage 1: fill missing values in column "f0" using the column maximum.
.add(new Imputer()
.setSelectedCols("f0")
.setStrategy(HasStrategy.Strategy.MAX)
)
// Stage 2: train softmax on the imputed features.
.add(new Softmax()
.setFeatureCols("f0", "f1")
.setLabelCol("label")
.setPredictionCol("pred")
);
采样(batch)
// Weighted sampling: selection probability is driven by the "weight" column.
// NOTE(review): `testArray` is defined elsewhere (presumably Object[][] rows) — confirm.
String[] colnames = new String[] {"id", "weight", "col0"};
MemSourceBatchOp inOp = new MemSourceBatchOp(Arrays.asList(testArray), colnames);
// Sample a 0.4 ratio of the data, drawing with replacement.
WeightSampleBatchOp sampleBatchOp = new WeightSampleBatchOp()
.setWeightCol("weight")
.setRatio(0.4)
.setWithReplacement(true)
.linkFrom(inOp);
按个数采样(batch)
# PyAlink: sample a fixed number of rows from a batch operator.
# NOTE(review): `df` is assumed to be a pandas DataFrame built earlier — confirm.
inOp = dataframeToOperator(df, schemaStr='Y string', op_type='batch')
# Draw exactly 2 rows, without replacement.
sampleOp = SampleWithSizeBatchOp()\
.setSize(2)\
.setWithReplacement(False)
# Link the sampler to the source and print the sampled rows.
inOp.link(sampleOp).print()
VectorNormalize
标准化
标准化训练(batch)
标准化预测(batch)
标准化预测(stream)
VectorSizeHit
VectorSizeHint(stream)
VectorSizeHint(batch)
VectorAssembler
VectorAssembler(stream)
VectorAssembler(batch)
VectorSlice
VectorSlice(stream)
VectorSlice(batch)
归一化
归一化预测(stream)
归一化训练(batch)
归一化预测(batch)
离散余弦变换
离散余弦变换(stream)
离散余弦变换(batch)
绝对值
绝对值最大标准化预测(stream)
绝对值最大标准化预测(batch)
绝对值最大标准化训练(batch)
VectorInteraction
VectorInteraction(stream)
VectorInteraction(batch)
VectorElementWiseProduct
VectorElementWiseProduct(stream)
VectorElementWiseProduct(batch)
StringIndexer
StringIndexer流式预测(stream)
StringIndexer训练(batch)
StringIndexer预测(batch)
IndexToString
IndexToString预测(stream)
IndexToString预测(batch)
VectorPolynomialExpand
VectorPolynomialExpand(stream)
VectorPolynomialExpand(batch)
添加id列(batch)
Json值抽取(stream)
Json值抽取(batch)
多列字符串编码
多列字符串编码预测(stream)
多列字符串编码预测(batch)
多列字符串编码训练(batch)
向量缺失值填充
向量缺失值填充预测(stream)
向量缺失值填充预测(batch)
向量缺失值填充训练(batch)
向量绝对值最大标准化
向量绝对值最大标准化预测(stream)
向量绝对值最大标准化预测(batch)
向量绝对值最大标准化训练(batch)
向量归一化
向量归一化预测(stream)
向量归一化预测(batch)
向量归一化训练(batch)
向量标准化
向量标准化预测(stream)
向量标准化预测(batch)
向量标准化训练(batch)
UDF & UDTF
批式UDF
流式UDF
批式UDTF
批式UDTF
流式UDTF
特征工程
Quantile离散
Quantile离散预测(stream)
Quantile离散预测(batch)
Quantile离散预测(batch)
Quantile离散训练(batch)
OneHot编码
OneHot编码预测(stream)
OneHot编码训练(batch)
OneHot编码预测(batch)
卡方筛选(batch)
二值化
二值化(stream)
二值化(batch)
特征哈希
特征哈希(stream)
特征哈希(batch)
特征分桶
特征分桶(stream)
特征分桶(batch)
主成分分析
主成分分析预测(stream)
主成分分析训练(batch)
主成分分析预测(batch)
回归
gbdt回归
gbdt回归预测(stream)
gbdt回归训练(batch)
gbdt回归预测(batch)
线性回归
线性回归训练(batch)
线性回归预测(batch)
线性回归预测(stream)
随机森林回归
随机森林回归预测(stream)
随机森林回归训练(batch)
随机森林回归预测(batch)
保序回归
保序回归预测(batch)
保序回归训练(batch)
保序回归预测(stream)
广义线性回归
广义线性回归预测(stream)
广义线性回归预测(batch)
广义线性回归训练(batch)
广义线性回归评估(batch)
AFT生存回归
AFT生存回归预测(batch)
AFT生存回归预测(stream)
AFT生存回归训练(batch)
决策树回归
决策树回归训练(batch)
决策树回归预测(batch)
决策树回归预测(stream)
Lasso回归
Lasso回归训练(batch)
Lasso回归预测(batch)
Lasso回归预测(stream)
岭回归
岭回归训练(batch)
岭回归预测(batch)
岭回归预测(stream)
聚类
二分K均值聚类
二分K均值聚类训练(batch)
二分K均值聚类预测(stream)
二分K均值聚类预测(batch)
KMeans
KMeans预测(stream)
KMeans预测(batch)
KMeans训练(batch)
高斯混合模型
高斯混合模型(batch)
高斯混合模型流式预测(stream)
高斯混合模型预测(batch)
LDA
LDA算法训练(batch)
LDA算法训练(batch)
LDA批预测(batch)
关联规则
FPGrowth算法(batch)
PrefixSpan算法(batch)
异常检测
SOS
SQL
As(batch)
As(stream)
Distinct(batch)
Filter(stream)
Filter(batch)
FullouterJoin(batch)
GroupBy(batch)
Intersect(batch)
IntersectAll(batch)
Join(batch)
LeftOuterJoin(batch)
RightOuterJoin(batch)
MinusAll(batch)
Minus(batch)
Orderby(batch)
Select
Select(stream)
Select(batch)
数据拆分(stream)
数据拆分(batch)
Union(batch)
UnionAll(stream)
UnionAll(batch)
Where(batch)
数据格式转换
AnyToTriple(batch)
ColumnsToCsv(batch)
ColumnsToJson(batch)
ColumnsToKv(batch)
ColumnsToTriple(batch)
ColumnsToVector(batch)
CsvToColumns(batch)
CsvToJson(batch)
CsvToKv(batch)
CsvToTriple(batch)
CsvToVector(batch)
JsonToColumns(batch)
JsonToCsv(batch)
JsonToKv(batch)
JsonToTriple(batch)
JsonToVector(batch)
KvToColumns(batch)
KvToCsv(batch)
KvToJson(batch)
KvToTriple(batch)
KvToVector(batch)
TripleToAny(batch)
TripleToColumns(batch)
TripleToCsv(batch)
TripleToJson(batch)
TripleToKv(batch)
TripleToVector(batch)
VectorToColumns(batch)
VectorToCsv(batch)
VectorToJson(batch)
VectorToKv(batch)
VectorToTriple(batch)
AnyToTriple(stream)
ColumnsToCsv(stream)
ColumnsToJson(stream)
ColumnsToKv(stream)
ColumnsToTriple(stream)
ColumnsToVector(stream)
CsvToColumns(stream)
CsvToJson(stream)
CsvToKv(stream)
CsvToTriple(stream)
CsvToVector(stream)
JsonToColumns(stream)
JsonToCsv(stream)
JsonToKv(stream)
JsonToTriple(stream)
JsonToVector(stream)
KvToColumns(stream)
KvToCsv(stream)
KvToJson(stream)
KvToTriple(stream)
KvToVector(stream)
VectorToColumns(stream)
VectorToCsv(stream)
VectorToJson(stream)
VectorToKv(stream)
VectorToTriple(stream)
ColumnsToCsv
ColumnsToJson
ColumnsToKv
ColumnsToVector
CsvToColumns
CsvToJson
CsvToKv
CsvToVector
JsonToColumns
JsonToCsv
JsonToKv
JsonToVector
KvToColumns
KvToCsv
KvToJson
KvToVector
VectorToColumns
VectorToCsv
VectorToJson
VectorToKv
源码解读
Pipeline
- Pipeline 内部维护了一个 ArrayList 来存储 PipelineStageBase
/**
 * A pipeline is a linear workflow which chains {@link EstimatorBase}s and {@link TransformerBase}s to
 * execute an algorithm.
 */
public class Pipeline extends EstimatorBase<Pipeline, PipelineModel> {
// Ordered list of stages; each is an estimator or a transformer.
private ArrayList<PipelineStageBase> stages = new ArrayList<>();
public Pipeline() {
this(new Params());
}
public Pipeline(Params params) {
super(params);
}
// Varargs convenience constructor: pre-populates the stage list.
public Pipeline(PipelineStageBase<?>... stages) {
super(null);
if (null != stages) {
this.stages.addAll(Arrays.asList(stages));
}
}
- Pipeline 的 add 方法
/**
 * Inserts the specified stage at the specified position in this
 * pipeline. Shifts the stage currently at that position (if any) and
 * any subsequent stages to the right (adds one to their indices).
 *
 * @param index index at which the specified stage is to be inserted
 * @param stage pipelineStage to be inserted
 * @return this pipeline, for call chaining
 * @throws IndexOutOfBoundsException if the index is out of range
 */
public Pipeline add(int index, PipelineStageBase stage) {
this.stages.add(index, stage);
return this;
}
/**
 * Appends the specified stage to the end of this pipeline.
 *
 * @param stage pipelineStage to be appended to this pipeline
 * @return this pipeline, for call chaining
 */
public Pipeline add(PipelineStageBase stage) {
// Delegates to the backing stage list; returning this enables chaining.
this.stages.add(stage);
return this;
}
- Pipeline 的 fit 方法会区分 EstimatorBase 和 TransformerBase 两类操作,并封装成 PipelineModel
/**
 * Train the pipeline with stream data.
 *
 * <p>Walks the stage list once: estimators up to (and including) the last
 * estimator are fitted and replaced by the transformers they produce,
 * while plain transformers are kept as-is. Intermediate data is only
 * pushed through a stage's transform while later estimators still need
 * it as training input.
 *
 * @param input input data
 * @return pipeline model
 */
@Override
public PipelineModel fit(StreamOperator input) {
int lastEstimatorIdx = getIndexOfLastEstimator();
TransformerBase[] transformers = new TransformerBase[stages.size()];
for (int i = 0; i < stages.size(); i++) {
PipelineStageBase stage = stages.get(i);
if (i <= lastEstimatorIdx) {
if (stage instanceof EstimatorBase) {
// Fitting an estimator yields a model, which is itself a transformer.
transformers[i] = ((EstimatorBase) stage).fit(input);
} else if (stage instanceof TransformerBase) {
transformers[i] = (TransformerBase) stage;
}
if (i < lastEstimatorIdx) {
// Feed the transformed data forward so downstream estimators train on it.
input = transformers[i].transform(input);
}
} else {
// After lastEstimatorIdx, there're only Transformer stages, so it's safe to do type cast.
transformers[i] = (TransformerBase) stage;
}
}
return new PipelineModel(transformers).setMLEnvironmentId(input.getMLEnvironmentId());
}
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bETR5e2k-1598841831494)(/home/edc/Alink.assets/image-20200807142321836.png)]
更多推荐
所有评论(0)