Alink

环境安装

安装

pip install pyalink-flink-1.10 --user -i https://mirrors.aliyun.com/pypi/simple

检查安装

pip show pyalink  （注：PyPI 已停用 pip search 命令，可改用 pip show pyalink 或 pip list 检查安装情况）

如下则是安装成功

pyalink (1.2.0)             - Alink Python API
pyalink-flink-1.9 (1.2.0)   - Alink Python API
pyalink-flink-1.10 (1.2.0)  - Alink Python API
  INSTALLED: 1.2.0 (latest)
 <properties>
        <java.version>1.8</java.version>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <encoding>UTF-8</encoding>
        <scala.version>2.11.8</scala.version>
        <scala.compat.version>2.11</scala.compat.version>
        <hadoop.version>2.7.2</hadoop.version>
        <flink.version>1.10.0</flink.version>
        <kafka.version>1.1.1</kafka.version>
    </properties>

代码测试

协同过滤

ALS

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.SplitBatchOp;
import com.alibaba.alink.operator.batch.recommendation.AlsTrainBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.PipelineStageBase;
import com.alibaba.alink.pipeline.clustering.KMeansModel;
import com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler;
import com.alibaba.alink.pipeline.recommendation.ALS;
import com.alibaba.alink.pipeline.recommendation.ALSModel;

/**
 * Created by edc on 2020/8/6
 */
public class PiplineExample {
    public static void main(String[] args) throws Exception {
        // Data source: MovieLens ratings CSV hosted on Aliyun OSS.
        String url = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/movielens_ratings.csv";
        String schema = "userid bigint, movieid bigint, rating double, timestamp string";

        BatchOperator source = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schema);

        // 80/20 split: the operator's main output carries the sampled fraction,
        // side output 0 carries the remaining rows.
        SplitBatchOp splitter = new SplitBatchOp().setFraction(0.8);
        splitter.linkFrom(source);

        BatchOperator trainSet = splitter;
        BatchOperator testSet = splitter.getSideOutput(0);

        // Single-stage pipeline wrapping the ALS recommender
        // (10 iterations, rank 10, regularization 0.1).
        Pipeline alsPipeline = new Pipeline()
                .add(new ALS()
                        .setUserCol("userid")
                        .setItemCol("movieid")
                        .setRateCol("rating")
                        .setNumIter(10)
                        .setRank(10)
                        .setLambda(0.1)
                        .setPredictionCol("pred_rating"));

        // Fit on the training split, persist the model, then score the held-out split.
        PipelineModel fittedModel = alsPipeline.fit(trainSet);
        fittedModel.save("/home/edc/alink-als.csv");
        fittedModel.transform(testSet).print();
    }
}

结果展示

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uM585LfQ-1598841831485)(/home/edc/Alink.assets/image-20200806101650749.png)]

分类

决策树

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.dataproc.SplitBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.common.tree.seriestree.DecisionTree;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.DecisionTreeClassificationModel;
import com.alibaba.alink.pipeline.classification.DecisionTreeClassifier;
import com.alibaba.alink.pipeline.recommendation.ALS;
import com.alibaba.alink.pipeline.regression.DecisionTreeRegressor;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_DecisionTreeRegressor {
    public static void main(String[] args) throws Exception {
        // Data source: local CSV with four feature columns and an integer label.
        String filepath = "/home/edc/exampleData.csv";
        String schema = "f0 double, f1 string, f2 bigint, f3 bigint ,label bigint";

        BatchOperator trainData = new CsvSourceBatchOp()
                .setFilePath(filepath).setSchemaStr(schema);

        trainData.print();

        // Build the pipeline.
        // BUG FIX: the original chained setFeatureCols("f0")...setFeatureCols("f3");
        // each call REPLACES the previous value, so the model silently trained on
        // "f3" alone. setFeatureCols is varargs — pass all columns in one call.
        Pipeline pipeline = new Pipeline();
        Pipeline pipeline1 = pipeline
                .add(new DecisionTreeClassifier()
                        .setFeatureCols("f0", "f1", "f2", "f3")
                        .setLabelCol("label")
                        .setPredictionCol("pred"));

        PipelineModel pipelineModel = pipeline1.fit(trainData);

        // Persist the fitted model, then score the training data with it.
        pipelineModel.save("/home/edc/alink-decisionTree.csv");

        pipelineModel.transform(trainData).print();

    }
}

结果展示

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8G9HLGyx-1598841831486)(/home/edc/Alink.assets/image-20200806134734142.png)]

随机森林

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LinearSvm;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_RandomForestClassifier {
    public static void main(String[] args) throws Exception {
        // Data source: comma-delimited local CSV, four features + integer label.
        String url = "/home/edc/exampleData.csv";
        String schema = "f0 double, f1 string, f2 bigint, f3 bigint ,label bigint";


        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url).setSchemaStr(schema);

        // Build the pipeline.
        // BUG FIX: the original chained setFeatureCols("f0")...setFeatureCols("f3");
        // each call REPLACES the previous value, so the forest trained on "f3" only.
        // setFeatureCols is varargs — pass all feature columns in a single call.
        Pipeline pipeline = new Pipeline();
        Pipeline pipeline1 = pipeline
                .add(new RandomForestClassifier()
                        .setFeatureCols("f0", "f1", "f2", "f3")
                        .setLabelCol("label")
                        .setPredictionCol("pred")
                        .setPredictionDetailCol("pred_detail"));

        PipelineModel pipelineModel = pipeline1.fit(trainData);

        // Persist the fitted model, then score the training data with it.
        pipelineModel.save("/home/edc/alink-RandomForestClassifier.csv");
        pipelineModel.transform(trainData).print();

    }
}

结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vwxYH2eC-1598841831487)(/home/edc/Alink.assets/image-20200806145242868.png)]

多层感知机

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.MultilayerPerceptronClassifier;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_MultilayerPerceptronClassifier {
    public static void main(String[] args) throws Exception {
        // Data source: iris dataset, four numeric features plus a string category.
        String url = "/home/edc/iris.csv";
        String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";

        BatchOperator trainData = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schema);

        // Single-stage pipeline: an MLP with layers 4 (inputs) -> 5 (hidden) -> 3 (classes).
        Pipeline mlpPipeline = new Pipeline()
                .add(new MultilayerPerceptronClassifier()
                        .setFeatureCols(new String[]{"sepal_length", "sepal_width", "petal_length", "petal_width"})
                        .setLabelCol("category")
                        .setLayers(new int[]{4, 5, 3})
                        .setMaxIter(20)
                        .setPredictionCol("pred_label")
                        .setPredictionDetailCol("pred_detail"));

        // Fit the model (result unused below; kept for parity with the save demo).
        PipelineModel fittedModel = mlpPipeline.fit(trainData);

        // Reload a previously saved model and score the first four rows.
        PipelineModel loadedModel = PipelineModel.load("/home/edc/alink-MultilayerPerceptronClassifier.csv");
        loadedModel.transform(trainData).firstN(4).print();
    }
}

结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CLyUtocL-1598841831490)(/home/edc/Alink.assets/image-20200806153549431.png)]

gbdt二分类

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.RandomForestClassifier;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_GbdtClassifier {
    public static void main(String[] args) throws Exception {
        // Data source: comma-delimited local CSV, four features + integer label.
        String url = "/home/edc/exampleData.csv";
        String schema = "f0 double, f1 string, f2 bigint, f3 bigint ,label bigint";

        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);

        // Single-stage pipeline wrapping a small GBDT (3 trees, learning rate 1.0,
        // minimum of 1 sample per leaf).
        Pipeline gbdtPipeline = new Pipeline()
                .add(new GbdtClassifier()
                        .setLearningRate(1.0)
                        .setNumTrees(3)
                        .setMinSamplesPerLeaf(1)
                        .setPredictionDetailCol("pred_detail")
                        .setPredictionCol("pred")
                        .setLabelCol("label")
                        .setFeatureCols("f0", "f1", "f2", "f3"));

        // Fit the model (result unused below; kept for parity with the save demo).
        PipelineModel fittedModel = gbdtPipeline.fit(trainData);

        // Reload a previously saved model and score the training data with it.
        PipelineModel loadedModel = PipelineModel.load("/home/edc/alink-GbdtClassifier.csv");
        loadedModel.transform(trainData).print();
    }
}

结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ged2DWkS-1598841831491)(/home/edc/Alink.assets/image-20200806154349592.png)]

softmax 算法

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.GbdtClassifier;
import com.alibaba.alink.pipeline.classification.Softmax;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_Softmax {
    public static void main(String[] args) throws Exception {
        // Data source: comma-delimited CSV, two integer features + integer label.
        String url = "/home/edc/alink-Softmax-data.csv";
        String schema = "f0 int, f1 int, label int";

        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);

        // Single-stage pipeline wrapping a multiclass Softmax classifier.
        Pipeline softmaxPipeline = new Pipeline()
                .add(new Softmax()
                        .setFeatureCols("f0", "f1")
                        .setLabelCol("label")
                        .setPredictionCol("pred"));

        // Fit the model (result unused below; kept for parity with the save demo).
        PipelineModel fittedModel = softmaxPipeline.fit(trainData);

        // Reload a previously saved model and score the training data with it.
        PipelineModel loadedModel = PipelineModel.load("/home/edc/alink-Softmax.csv");
        loadedModel.transform(trainData).print();
    }
}

结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8MvD3Ker-1598841831492)(/home/edc/Alink.assets/image-20200806155103625.png)]

逻辑回归

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.Softmax;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_LogisticRegression {
    public static void main(String[] args) throws Exception {
        // Data source: comma-delimited CSV, two integer features + integer label.
        String url = "/home/edc/alink-Softmax-data.csv";
        String schema = "f0 int, f1 int, label int";

        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);

        // Single-stage pipeline wrapping a binary logistic-regression classifier.
        Pipeline lrPipeline = new Pipeline()
                .add(new LogisticRegression()
                        .setFeatureCols("f0", "f1")
                        .setLabelCol("label")
                        .setPredictionCol("pred"));

        // Fit the model (result unused below; kept for parity with the save demo).
        PipelineModel fittedModel = lrPipeline.fit(trainData);

        // Reload a previously saved model and score the training data with it.
        PipelineModel loadedModel = PipelineModel.load("/home/edc/alink-LogisticRegression.csv");
        loadedModel.transform(trainData).print();
    }
}

结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qoQs21qy-1598841831492)(/home/edc/Alink.assets/image-20200806155809839.png)]

OneVsRest

package com.alink.ml;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.classification.OneVsRest;
import com.alibaba.alink.pipeline.classification.Softmax;

/**
 * Created by edc on 2020/8/6
 */
public class Pipline_OneVsRest {
    public static void main(String[] args) throws Exception {
        // Data source: iris dataset (3 classes), comma-delimited.
        String url = "/home/edc/iris.csv";
        String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";

        BatchOperator trainData = new CsvSourceBatchOp()
                .setFieldDelimiter(",")
                .setFilePath(url)
                .setSchemaStr(schema);

        // One-vs-rest meta-classifier over 3 classes: trains one
        // logistic-regression model per class.
        Pipeline ovrPipeline = new Pipeline()
                .add(new OneVsRest()
                        .setClassifier(new LogisticRegression()
                                .setFeatureCols("sepal_length", "sepal_width", "petal_length", "petal_width")
                                .setLabelCol("category")
                                .setMaxIter(100))
                        .setNumClass(3)
                        .setPredictionCol("pred_result")
                        .setPredictionDetailCol("pred_detail"));

        PipelineModel fittedModel = ovrPipeline.fit(trainData);

        // Persist the fitted model and score the training data with it.
        fittedModel.save("/home/edc/alink-OneVsRest.csv");
        fittedModel.transform(trainData).print();
    }
}

结果

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Wtm8jKaz-1598841831493)(/home/edc/Alink.assets/image-20200806160754376.png)]

统计分析

相关系数(batch)

卡方检验(batch)

全表统计(batch)

数据处理

缺失值填充

Pipeline pipeline = new Pipeline();
// Two-stage pipeline: first impute missing values in column f0, then train a
// Softmax classifier on the (now complete) feature columns.
Pipeline pipeline1 = pipeline
        .add(new Imputer()
                .setSelectedCols("f0")
                // MAX strategy: fill missing entries with the column's maximum value.
                .setStrategy(HasStrategy.Strategy.MAX)
        )
        .add(new Softmax()
                .setFeatureCols("f0", "f1")
                .setLabelCol("label")
                .setPredictionCol("pred")
        );

采样(batch)

 String[] colnames = new String[] {"id", "weight", "col0"};
        // In-memory batch source built from testArray (defined elsewhere).
        MemSourceBatchOp inOp = new MemSourceBatchOp(Arrays.asList(testArray), colnames);
        // Weighted sampling: each row's selection probability is driven by the
        // "weight" column; ratio 0.4 keeps ~40% of rows, with replacement enabled.
        WeightSampleBatchOp sampleBatchOp = new WeightSampleBatchOp()
            .setWeightCol("weight")
            .setRatio(0.4)
            .setWithReplacement(true)
            .linkFrom(inOp);

按个数采样(batch)

# Build a batch operator from a pandas DataFrame with a single string column Y.
inOp = dataframeToOperator(df, schemaStr='Y string', op_type='batch')
# Sample exactly 2 rows without replacement.
sampleOp = SampleWithSizeBatchOp()\
        .setSize(2)\
        .setWithReplacement(False)
inOp.link(sampleOp).print()

VectorNormalize

标准化

标准化训练(batch)
标准化预测(batch)
标准化预测(stream)

VectorSizeHint

VectorSizeHint(stream)
VectorSizeHint(batch)

VectorAssembler

VectorAssembler(stream)
VectorAssembler(batch)

VectorSlice

VectorSlice(stream)
VectorSlice(batch)

归一化

归一化预测(stream)
归一化训练(batch)
归一化预测(batch)

离散余弦变换

离散余弦变换(stream)
离散余弦变换(batch)

绝对值

绝对值最大标准化预测(stream)
绝对值最大标准化预测(batch)
绝对值最大标准化训练(batch)

VectorInteraction

VectorInteraction(stream)
VectorInteraction(batch)

VectorElementWiseProduct

VectorElementWiseProduct(stream)
VectorElementWiseProduct(batch)

StringIndexer

StringIndexer流式预测(stream)
StringIndexer训练(batch)
StringIndexer预测(batch)

IndexToString

IndexToString预测(stream)
IndexToString预测(batch)

VectorPolynomialExpand

VectorPolynomialExpand(stream)
VectorPolynomialExpand(batch)

添加id列(batch)

Json值抽取(stream)

Json值抽取(batch)

多列字符串编码

多列字符串编码预测(stream)
多列字符串编码预测(batch)
多列字符串编码训练(batch)

向量缺失值填充

向量缺失值填充预测(stream)
向量缺失值填充预测(batch)
向量缺失值填充训练(batch)

向量绝对值最大标准化

向量绝对值最大标准化预测(stream)
向量绝对值最大标准化预测(batch)
向量绝对值最大标准化训练(batch)

向量归一化

向量归一化预测(stream)
向量归一化预测(batch)
向量归一化训练(batch)

向量标准化

向量标准化预测(stream)
向量标准化预测(batch)
向量标准化训练(batch)

UDF & UDTF

批式UDF
流式UDF
批式UDTF
批式UDTF
流式UDTF

特征工程

Quantile离散

Quantile离散预测(stream)
Quantile离散预测(batch)
Quantile离散预测(batch)
Quantile离散训练(batch)

OneHot编码

OneHot编码预测(stream)
OneHot编码训练(batch)
OneHot编码预测(batch)

卡方筛选(batch)

二值化

二值化(stream)
二值化(batch)

特征哈希

特征哈希(stream)
特征哈希(batch)

特征分桶

特征分桶(stream)
特征分桶(batch)

主成分分析

主成分分析预测(stream)
主成分分析训练(batch)
主成分分析预测(batch)

回归

gbdt回归

gbdt回归预测(stream)
gbdt回归训练(batch)
gbdt回归预测(batch)

线性回归

线性回归训练(batch)
线性回归预测(batch)
线性回归预测(stream)

随机森林回归

随机森林回归预测(stream)
随机森林回归训练(batch)
随机森林回归预测(batch)

保序回归

保序回归预测(batch)
保序回归训练(batch)
保序回归预测(stream)

广义线性回归

广义线性回归预测(stream)
广义线性回归预测(batch)
广义线性回归训练(batch)
广义线性回归评估(batch)

AFT生存回归

AFT生存回归预测(batch)
AFT生存回归预测(stream)
AFT生存回归训练(batch)

决策树回归

决策树回归训练(batch)
决策树回归预测(batch)
决策树回归预测(stream)

Lasso回归

Lasso回归训练(batch)
Lasso回归预测(batch)
Lasso回归预测(stream)

岭回归

岭回归训练(batch)
岭回归预测(batch)
岭回归预测(stream)

聚类

二分K均值聚类

二分K均值聚类训练(batch)
二分K均值聚类预测(stream)
二分K均值聚类预测(batch)

KMeans

KMeans预测(stream)
KMeans预测(batch)
KMeans训练(batch)

高斯混合模型

高斯混合模型(batch)
高斯混合模型流式预测(stream)
高斯混合模型预测(batch)

LDA

LDA算法训练(batch)
LDA算法训练(batch)
LDA批预测(batch)

关联规则

FPGrowth算法(batch)

PrefixSpan算法(batch)

异常检测

SOS

SQL

As(batch)

As(stream)

Distinct(batch)

Filter(stream)

Filter(batch)

FullouterJoin(batch)

GroupBy(batch)

Intersect(batch)

IntersectAll(batch)

Join(batch)

LeftOuterJoin(batch)

RightOuterJoin(batch)

MinusAll(batch)

Minus(batch)

Orderby(batch)

Select

Select(stream)

Select(batch)

数据拆分(stream)

数据拆分(batch)

Union(batch)

UnionAll(stream)

UnionAll(batch)

Where(batch)

数据格式转换

AnyToTriple(batch)
ColumnsToCsv(batch)
ColumnsToJson(batch)
ColumnsToKv(batch)
ColumnsToTriple(batch)
ColumnsToVector(batch)
CsvToColumns(batch)
CsvToJson(batch)

CsvToKv(batch)

CsvToTriple(batch)
CsvToVector(batch)
JsonToColumns(batch)

JsonToCsv(batch)

JsonToKv(batch)
JsonToTriple(batch)
JsonToVector(batch)
KvToColumns(batch)
KvToCsv(batch)
KvToJson(batch)
KvToTriple(batch)
KvToVector(batch)
TripleToAny(batch)
TripleToColumns(batch)
TripleToCsv(batch)
TripleToJson(batch)
TripleToKv(batch)
TripleToVector(batch)
VectorToColumns(batch)
VectorToCsv(batch)
VectorToJson(batch)
VectorToKv(batch)
VectorToTriple(batch)
AnyToTriple(stream)
ColumnsToCsv(stream)
ColumnsToJson(stream)
ColumnsToKv(stream)
ColumnsToTriple(stream)
ColumnsToVector(stream)
CsvToColumns(stream)
CsvToJson(stream)
CsvToKv(stream)
CsvToTriple(stream)
CsvToVector(stream)
JsonToColumns(stream)
JsonToCsv(stream)
JsonToKv(stream)
JsonToTriple(stream)
JsonToVector(stream)
KvToColumns(stream)
KvToCsv(stream)
KvToJson(stream)
KvToTriple(stream)
KvToVector(stream)
VectorToColumns(stream)
VectorToCsv(stream)
VectorToJson(stream)
VectorToKv(stream)
VectorToTriple(stream)
ColumnsToCsv
ColumnsToJson
ColumnsToKv
ColumnsToVector
CsvToColumns
CsvToJson
CsvToKv
CsvToVector
JsonToColumns

JsonToCsv

JsonToKv
JsonToVector
KvToColumns
KvToCsv

KvToJson

KvToVector

VectorToColumns
VectorToCsv
VectorToJson
VectorToKv

源码解读

Pipeline

  1. Pipeline 里面维护了一个 ArrayList 来存储 PipelineStageBase
/**
 * A pipeline is a linear workflow which chains {@link EstimatorBase}s and {@link TransformerBase}s to
 * execute an algorithm.
 */
public class Pipeline extends EstimatorBase<Pipeline, PipelineModel> {

	private ArrayList<PipelineStageBase> stages = new ArrayList<>();

	// No-arg convenience constructor: delegates to Pipeline(Params) with an empty parameter set.
	public Pipeline() {
		this(new Params());
	}

	// Constructs a pipeline carrying the given parameter set.
	public Pipeline(Params params) {
		super(params);
	}

	// Varargs constructor: seeds the pipeline with the given stages.
	// Null-safe — a null stages array yields an empty pipeline.
	public Pipeline(PipelineStageBase<?>... stages) {
		super(null);
		if (null != stages) {
			this.stages.addAll(Arrays.asList(stages));
		}
	}

  1. Pipeline 的 add 方法
	/**
	 * Inserts the specified stage at the specified position in this
	 * pipeline. Shifts the stage currently at that position (if any) and
	 * any subsequent stages to the right (adds one to their indices).
	 *
	 * @param index index at which the specified stage is to be inserted
	 * @param stage pipelineStage to be inserted
	 * @return this pipeline, enabling fluent chaining
	 * @throws IndexOutOfBoundsException if the index is out of range
	 */
	public Pipeline add(int index, PipelineStageBase stage) {
		this.stages.add(index, stage);
		return this;
	}
	/**
	 * Appends the specified stage to the end of this pipeline.
	 *
	 * @param stage pipelineStage to be appended to this pipeline
	 * @return this pipeline, enabling fluent chaining
	 */
	public Pipeline add(PipelineStageBase stage) {
		this.stages.add(stage);
		return this;
	}
  1. Pipeline 的 fit 方法会区分 EstimatorBase 和 TransformerBase 操作，并封装成 PipelineModel
	/**
	 * Train the pipeline with stream data.
	 *
	 * @param input input data
	 * @return pipeline model
	 */
	@Override
	public PipelineModel fit(StreamOperator input) {
		int lastEstimatorIdx = getIndexOfLastEstimator();
		TransformerBase[] transformers = new TransformerBase[stages.size()];
		for (int i = 0; i < stages.size(); i++) {
			PipelineStageBase stage = stages.get(i);
			if (i <= lastEstimatorIdx) {
				// Up to (and including) the last estimator: estimators are fitted
				// into transformers, plain transformers are used as-is.
				if (stage instanceof EstimatorBase) {
					transformers[i] = ((EstimatorBase) stage).fit(input);
				} else if (stage instanceof TransformerBase) {
					transformers[i] = (TransformerBase) stage;
				}
				if (i < lastEstimatorIdx) {
					// Feed the transformed output forward so later estimators are
					// fitted on data processed by all preceding stages. Not needed
					// after the last estimator, since nothing downstream is fitted.
					input = transformers[i].transform(input);
				}
			} else {
				// After lastEstimatorIdx, there're only Transformer stages, so it's safe to do type cast.
				transformers[i] = (TransformerBase) stage;
			}
		}
		return new PipelineModel(transformers).setMLEnvironmentId(input.getMLEnvironmentId());
	}

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bETR5e2k-1598841831494)(/home/edc/Alink.assets/image-20200807142321836.png)]

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐