亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

Java開發(fā)Spark應(yīng)用程序自定義PipeLineStage詳解

 更新時(shí)間:2023年02月01日 11:15:42   作者:KYs_Daddy  
這篇文章主要為大家介紹了Java開發(fā)Spark應(yīng)用程序自定義PipeLineStage詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

引言

在Spark中使用Pipeline進(jìn)行數(shù)據(jù)建模是一種非常高效的手段。作為Pipeline中基本數(shù)據(jù)加工處理單元——PipelineStage,Spark提供了用戶自定義的抽象子類Transformer和Estimator。

關(guān)于自定義PipelineStage的詳細(xì)方法,大部分的資料和介紹都是基于scala的。少數(shù)基于Java的介紹都極不完整,有些可能還存在一定的誤導(dǎo)。所以接下來我們將系統(tǒng)的介紹用Java開發(fā)Spark如何自定義PipelineStage。

本文使用環(huán)境:Spark-2.3.0,Java 8。

背景知識(shí)介紹

在spark中構(gòu)建一條Pipeline需要串聯(lián)多個(gè)PipelineStage,每個(gè)PipelineStage單獨(dú)處理一個(gè)數(shù)據(jù)加工環(huán)節(jié),如數(shù)據(jù)清洗、特征提取、特征選擇、預(yù)估等。PipelineStage按是否有訓(xùn)練訓(xùn)練方法分為Transformer和Estimator兩個(gè)抽象子類。其中Estimator可以進(jìn)行訓(xùn)練,有fit抽象方法要實(shí)現(xiàn)。

用于分類回歸等任務(wù)的Predictor都繼承于Estimator;而Transformer無需訓(xùn)練,沒有fit方法,一般的數(shù)據(jù)轉(zhuǎn)換器如VectorAssembler、StopWordsRemover等都是Transformer的子類。值得注意的是,所有由Estimator訓(xùn)練得到的Model類也都是Transformer的子類。

自定義PipelineStage需要繼承Transformer或Estimator并實(shí)現(xiàn)他們的方法。除此之外,我們自定義的PipelineStage要能同其他官方定義的PipelineStage一樣按照統(tǒng)一的讀寫流程進(jìn)行存儲(chǔ)和加載。PipelineStage讀寫基于Param對(duì)象,PipelineStage中的成員變量需要用Param類進(jìn)行封裝,然后用PipelineStage類中已實(shí)現(xiàn)的Params接口方法對(duì)封裝后的成員變量進(jìn)行統(tǒng)一訪問和處理。由于PipelineStage具有以上特性,我們自定義PipelineStage至少需要以下幾個(gè)步驟:

  • 繼承Transformer或Estimator抽象類;
  • 定義由Param封裝的成員變量,并通過調(diào)用由PipelineStage類實(shí)現(xiàn)的Params接口方法,定義對(duì)成員變量進(jìn)行操作的方法;
  • 實(shí)現(xiàn)Transformer或Estimator中的fit、transform等核心抽象方法;
  • 定義或?qū)崿F(xiàn)讀寫方法用于存儲(chǔ)和加載對(duì)象實(shí)例。

下面我們將自定義一個(gè)Transformer,并對(duì)其中的一些細(xì)節(jié)與要點(diǎn)進(jìn)行詳述。

定義一個(gè)Transformer

1. 場(chǎng)景介紹

本例定義一個(gè)名為SeqAssembler的Transformer,用于提取用戶最近n次(n>=0,包括本單)下單的序列特征。 輸入Dataset包括以下字段:user_id, buy_rn, feat1, feat2, feat3。經(jīng)過SeqAssembler后輸出:user_id, buy_rn, feat1, feat2, feat3, features。 其中features為數(shù)組類型,shape (3n, ) :
輸入:

輸出:

2. 代碼實(shí)現(xiàn)

2.1 定義并封裝成員變量

SeqAssembler中要定義如下成員變量:

private String idCol;
private String rnCol;
private String[] featCols;
private String outputCol;
private Integer limitRn;

使用org.apache.spark.ml.param.Param對(duì)成員變量進(jìn)行封裝,String[] 用StringArrayParam封裝, Integer成員變量采用String類型的Param進(jìn)行封裝,方便保存的時(shí)候進(jìn)行Json化,封裝后的成員變量如下:

private Param<String> idCol;
private Param<String> rnCol;
private StringArrayParam featCols;
private Param<String> outputCol;
private Param<String> limitRn; //Integer成員變量需要用String類型Param封裝,由于保存時(shí)要調(diào)用JsonEncoder方法,JsonEncoder僅支持String、數(shù)組等類型的數(shù)據(jù)。

此外我們還需要定義一個(gè)名為uid成員變量,用于識(shí)別SeqAssembler對(duì)象,并定義至少兩個(gè)構(gòu)造器,需要注意的細(xì)節(jié)如下:

  • uid不用聲明成靜態(tài)的,同一Spark進(jìn)程下初始化多個(gè)SeqAssembler對(duì)象,每個(gè)SeqAssembler對(duì)象都要有自己的uid,不用全局唯一。
  • uid的初始化需要在Param成員變量初始化之前,有了uid之后才能進(jìn)行Param成員變量的初始化。
  • 需要至少定義兩個(gè)構(gòu)造器,其中一個(gè)是無參構(gòu)造器,另一個(gè)是需要傳入唯一參數(shù)String uid的有參構(gòu)造器,有參構(gòu)造器用于load過程中構(gòu)造SeqAssembler對(duì)象。各成員變量需要在構(gòu)造器中完成初始化。

至此,SeqAssembler類中定義內(nèi)容如下:

public class SeqAssembler extends Transformer {
    private String uid;
    private Param<String> idCol;
    private Param<String> rnCol;
    private StringArrayParam featCols;
    private Param<String> outputCol;
    private Param<String> limitRn;
    /**
     * 定義一個(gè)輔助Param初始化的方法,在構(gòu)造器中對(duì)各Param成員變量進(jìn)行初始化
     */
    public void initParam(){
        idCol = new Param<String>(this,"idCol","Column name for id");
        rnCol = new Param<String>(this,"rnCol","Column name for sequential rn");
        featCols = new StringArrayParam(this,"featCols","Column names of features");
        outputCol = new Param<String>(this,"outputCol","Column name of output");
        limitRn = new Param<String>(this,"limitRn","Column name of limitRn");
    }
    public SeqAssembler() {
        uid = Identifiable$.MODULE$.randomUID("SeqAssembler"); //uid初始化在Param類型成員變量前
        initParam();
    }
    public SeqAssembler(String value){
        uid = value; //uid初始化在Param類型成員變量前
        initParam();
    }
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        return null;
    }
    @Override
    public StructType transformSchema(StructType schema) {
        return null;
    }
    @Override
    public Transformer copy(ParamMap extra) {
        return null;
    }
    @Override
    public String uid() {
        return null;
    }
}

接著定義get、set方法,調(diào)用PipelineStage類中已實(shí)現(xiàn)的Params接口下的$()和set()方法,方便對(duì)Param封裝后的成員變量進(jìn)行賦值取值操作。

   /**
     * 定義對(duì)Param成員變量進(jìn)行操作的get/set方法, 通過調(diào)用PipelineStage類中已實(shí)現(xiàn)的Params的$()、set()方法對(duì)
     * Param成員變量進(jìn)行操作。
     * $()、set()對(duì)Param進(jìn)行操作前會(huì)調(diào)用shouldOwn(),驗(yàn)證被操作的Param成員變量是否已經(jīng)被維護(hù)到params數(shù)組中
     */
    public String getIdCol() {
        return this.$(idCol);
    }
    public SeqAssembler setIdCol(String value) {
        return (SeqAssembler) this.<String>set(idCol,value);
    }
    public String getRnCol() {
        return this.$(rnCol);
    }
    public SeqAssembler setRnCol(String value) {
        return (SeqAssembler) this.<String>set(rnCol,value);
    }
    public String[] getFeatCols() {
        return this.$(featCols);
    }
    public SeqAssembler setFeatCols(String[] value) {
        return (SeqAssembler) this.<String[]>set(featCols,value);
    }
    public String getOutputCol() {
        return this.$(outputCol);
    }
    public SeqAssembler setOutputCol(String value) {
        return (SeqAssembler) this.<String>set(outputCol,value);
    }
    public Integer getLimitRn() {
        return Integer.parseInt(this.$(limitRn));
    }
    public SeqAssembler setLimitRn(Integer value) {
        return (SeqAssembler) this.<String>set(limitRn,value.toString());
    }

此外,我們還需要為每個(gè)Param定義一個(gè)public方法,因?yàn)镻arams接口會(huì)延遲加載并生成一個(gè)名為params數(shù)組。延遲加載時(shí)通過反射掃描一遍public方法, 將作為返回值的Param成員變量維護(hù)進(jìn)params數(shù)組中。

Params源碼中通過反射延遲加載params數(shù)組:

DefaultParamsReader的load方法中通過params數(shù)組反射構(gòu)造對(duì)象:

如果Param封裝的字段缺乏作用域pubic、無參、返回類型為對(duì)應(yīng)Param的方法,在load過程中通過反射構(gòu)造出的對(duì)象會(huì)出現(xiàn)成員變量缺失,用讀取的metadata裝配時(shí)會(huì)出錯(cuò)。 因此我們需要為每個(gè)Param定義如下方法:

    /**
     * 需要為每個(gè)Param定義一個(gè)public方法, 因?yàn)镻arams會(huì)延遲加載并生成一個(gè)Param[] params數(shù)組,
     * params的生成方式是通過反射掃描一遍public方法, 將作為返回值的Param成員變量維護(hù)進(jìn)params數(shù)組中。
     *
     * org.apache.spark.ml.param.shared下的所有接口都有一個(gè)以Param類型為返回值的方法,也是為了方便子類
     * 通過實(shí)現(xiàn)org.apache.spark.ml.param.shared接口,達(dá)到將Param成員變量維護(hù)進(jìn)params數(shù)組的目的。
     */
    public Param<String> idCol(){
        return idCol;
    }
    public Param<String> rnCol(){
        return rnCol;
    }
    public StringArrayParam featCols(){
        return featCols;
    }
    public Param<String> outputCol(){
        return outputCol;
    }
    public Param<String> limitRn(){
        return limitRn;
    }

如果研究spark ml的源碼不難發(fā)現(xiàn),官方的各個(gè)Transformer子類都實(shí)現(xiàn)org.apache.spark.ml.param.shared包下HasInputCols、HasOutputCol等接口,這些接口下都有一個(gè)滿足以上3要素(public、無參、Param類型返回)的方法,用途與我們上面為每個(gè)Param定義的方法類似。

2.2 實(shí)現(xiàn)抽象方法

接下來,我們需要實(shí)現(xiàn)從Transformer類中繼承來各個(gè)抽象方法,包括transform、transformSchema、copy、uid。

transform方法中包含的是整個(gè)數(shù)據(jù)處理的邏輯,該方法定義的原則是不改變?cè)瓟?shù)據(jù)的字段與條數(shù),只在原數(shù)據(jù)基礎(chǔ)上新增字段。下面實(shí)現(xiàn)的transform方法用于本例中最近幾次下單 特征的提取。

@Override
public Dataset<Row> transform(Dataset<?> dataset) {
    Dataset<Row> df = dataset.toDF();
    String idColName = getIdCol();
    String rnColName = getRnCol();
    String[] featCols = getFeatCols();
    String outputCol = getOutputCol();
    Integer limitRnValue = getLimitRn();
    // 獲取原始數(shù)據(jù)中rn字段下最大值
    Integer maxRN = (Integer) df.groupBy().max(rnColName).first().get(0);
    // 限制設(shè)置的limitRN不得大于maxRn。
    if(limitRnValue>maxRN){
        throw new ValueException(String.format( "the value of limitRn %d is larger than max value of rnCol %d, choose a smaller limitRn instead",
                limitRnValue,maxRN));
    }
//        定義一個(gè)備用的Dataset df_c
    Dataset<Row> df_c = df.select(idColName,rnColName);
    df_c = df_c.withColumnRenamed(idColName,idColName+"_c")
            .withColumnRenamed(rnColName, rnColName+"_c");
//        將df與df_c進(jìn)行連接,連接條件df.idCol==df_c.idCol && df.rnCol<=df_c.rnCol
    Column joinExpr = df.col(idColName).equalTo(df_c.col(idColName+"_c")).and(df.col(rnColName).leq(df_c.col(rnColName+"_c")));
    Dataset<Row> joinedDf = df_c.join(df,joinExpr,"left");
//        打上一列rnCol_p = df_c.rnCol - df.rnCol 最近購(gòu)買次序列,當(dāng)前次的值0
    String pivotRnColName = rnColName+"_p";
    joinedDf = joinedDf.withColumn(pivotRnColName,joinedDf.col(rnColName+"_c").minus(joinedDf.col(rnColName)));
//        表格透視前的準(zhǔn)備工作,定義一些map和array,用于記錄表格透視計(jì)算規(guī)則和透視后的列名
    Map<String, String> featAggMap = new HashMap<>();
    Integer featNums = featCols.length;
    String[] pivotColNames = new String[maxRN*featNums-1];
    String firstPivotColName = "0_min"+"("+featCols[0]+")";
    int n = 0;
    for(int i=0; i<maxRN; i++){
        for(String feat:featCols){
            if(i==0){
                featAggMap.put(feat,"min");
            }
            if(n>0){
                pivotColNames[n - 1] = String.valueOf(i) + "_min" + "(" + feat + ")";
            }
            n++;
        }
    }
//        對(duì)表格進(jìn)行透視、特征字段合并、得到outputCol
    Dataset<Row> transformed = joinedDf.groupBy(joinedDf.col(idColName+"_c"), joinedDf.col(rnColName+"_c")).pivot(pivotRnColName).agg(featAggMap);
    transformed = transformed.withColumn(outputCol, functions.array(firstPivotColName,pivotColNames));
//        將outputCol連到原df上,保證經(jīng)過transform后的df只在原數(shù)據(jù)基礎(chǔ)上新增一列
    Column joinExprT = df.col(idColName).equalTo(transformed.col(idColName+"_c")).and(df.col(rnColName).equalTo(transformed.col(rnColName+"_c")));
    df = df.join(transformed.select(idColName+"_c",rnColName+"_c",outputCol),joinExprT,"left").drop(idColName+"_c",rnColName+"_c");
    return df;
}

實(shí)現(xiàn)transformSchema方法,通常在其中定義輸入數(shù)據(jù)類型判斷的邏輯,并返回一個(gè)與transform方法輸出的Dataset相對(duì)應(yīng)的schema:

    @Override
    /**
     * transformSchema中定義輸入數(shù)據(jù)類型判斷的邏輯,并返回一個(gè)與transform方法輸出的Dataset相對(duì)應(yīng)的schema
     */
    public StructType transformSchema(StructType schema) {
        HashSet<String> featColSet = new HashSet<String>(Arrays.asList(getFeatCols()));
        StructField[] fields = schema.fields();
        for(StructField field:fields){
            if(featColSet.contains(field.name())){
                if(!field.dataType().sameType(DoubleType)&&!field.dataType().sameType(IntegerType)){
                    throw new TypeConstraintException(String.format("featCol DataType need DoubleType or IntegerType, " +
                            "but column %s is a %s." ,field.name(),field.dataType().typeName()));
                }
            }
        }
        StructType addedSchema = schema.add(getOutputCol(), new VectorUDT(), true);
        return addedSchema;
    }

實(shí)現(xiàn)uid與copy方法:

    @Override
    public Transformer copy(ParamMap extra) {
        return this.<SeqAssembler>defaultCopy(extra);
    }
    @Override
    public String uid() {
        return uid;
    }

最后,我們需要實(shí)現(xiàn)和定義讀寫方法,其中用于寫的兩個(gè)方法write()、save()通過實(shí)現(xiàn) DefaultParamsWritable接口來實(shí)現(xiàn);用于讀的兩個(gè)方法read()、load()直接自定義實(shí)現(xiàn),需要聲明為靜態(tài)方法。

    /**
     * 調(diào)用DefaultParamsWriter和DefaultParamsReader實(shí)現(xiàn)write()/save(), read()/load()方法.
     */
    @Override
    public MLWriter write() {
        MLWriter defaultParamsWriter = new DefaultParamsWriter(this);
        return defaultParamsWriter;
    }
    @Override
    public void save(String path) throws IOException {
        write().saveImpl(path);
    }
    public static MLReader read() {
        MLReader defaultParamsReader = new DefaultParamsReader();
        return defaultParamsReader;
    }
    public static SeqAssembler load(String path) {
        return (SeqAssembler) read().load(path);
    }

最終完整的SeqAssembler類如下:

import jdk.nashorn.internal.runtime.regexp.joni.exception.ValueException;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.util.*;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import static org.apache.spark.sql.types.DataTypes.DoubleType;
import static org.apache.spark.sql.types.DataTypes.IntegerType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.actors.threadpool.Arrays;
import javax.xml.bind.TypeConstraintException;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
/**
 * @author wangjiahui
 * @create 2021-03-12-21:00
 */
public class SeqAssembler extends Transformer implements DefaultParamsWritable {
    private String uid;
    private Param<String> idCol;
    private Param<String> rnCol;
    private StringArrayParam featCols;
    private Param<String> outputCol;
    private Param<String> limitRn;
    /**
     * 定義一個(gè)輔助Param初始化的方法,在構(gòu)造器中對(duì)各Param成員變量進(jìn)行初始化
     */
    public void initParam(){
        idCol = new Param<String>(this,"idCol","Column name for id");
        rnCol = new Param<String>(this,"rnCol","Column name for sequential rn");
        featCols = new StringArrayParam(this,"featCols","Column names of features");
        outputCol = new Param<String>(this,"outputCol","Column name of output");
        limitRn = new Param<String>(this,"limitRn","Column name of limitRn");
    }
    public SeqAssembler() {
        uid = Identifiable$.MODULE$.randomUID("SeqAssembler"); //uid初始化在Param類型成員變量前
        initParam();
    }
    public SeqAssembler(String value){
        uid = value; //uid初始化在Param類型成員變量前
        initParam();
    }
    /**
     * 需要為每個(gè)Param定義一個(gè)public方法, 因?yàn)镻arams會(huì)延遲加載并生成一個(gè)Param[] params數(shù)組,
     * params的生成方式是通過反射掃描一遍public方法, 將作為返回值的Param成員變量維護(hù)進(jìn)params數(shù)組中。
     *
     * org.apache.spark.ml.param.shared下的所有接口都有一個(gè)以Param類型為返回值的方法,也是為了方便子類
     * 通過實(shí)現(xiàn)org.apache.spark.ml.param.shared接口,達(dá)到將Param成員變量維護(hù)進(jìn)params數(shù)組的目的。
     */
    public Param<String> idCol(){
        return idCol;
    }
    public Param<String> rnCol(){
        return rnCol;
    }
    public StringArrayParam featCols(){
        return featCols;
    }
    public Param<String> outputCol(){
        return outputCol;
    }
    public Param<String> limitRn(){
        return limitRn;
    }
    /**
     * 定義對(duì)Param成員變量進(jìn)行操作的get/set方法, 通過調(diào)用PipelineStage類中已實(shí)現(xiàn)的Params的$()、set()方法對(duì)
     * Param成員變量進(jìn)行操作。
     * $()、set()對(duì)Param進(jìn)行操作前會(huì)調(diào)用shouldOwn(),驗(yàn)證被操作的Param成員變量是否已經(jīng)被維護(hù)到params數(shù)組中
     */
    public String getIdCol() {
        return this.$(idCol);
    }
    public SeqAssembler setIdCol(String value) {
        return (SeqAssembler) this.<String>set(idCol,value);
    }
    public String getRnCol() {
        return this.$(rnCol);
    }
    public SeqAssembler setRnCol(String value) {
        return (SeqAssembler) this.<String>set(rnCol,value);
    }
    public String[] getFeatCols() {
        return this.$(featCols);
    }
    public SeqAssembler setFeatCols(String[] value) {
        return (SeqAssembler) this.<String[]>set(featCols,value);
    }
    public String getOutputCol() {
        return this.$(outputCol);
    }
    public SeqAssembler setOutputCol(String value) {
        return (SeqAssembler) this.<String>set(outputCol,value);
    }
    public Integer getLimitRn() {
        return Integer.parseInt(this.$(limitRn));
    }
    public SeqAssembler setLimitRn(Integer value) {
        return (SeqAssembler) this.<String>set(limitRn,value.toString());
    }
    @Override
    public Dataset<Row> transform(Dataset<?> dataset) {
        Dataset<Row> df = dataset.toDF();
        transformSchema(dataset.schema());
        String idColName = getIdCol();
        String rnColName = getRnCol();
        String[] featCols = getFeatCols();
        String outputCol = getOutputCol();
        Integer limitRnValue = getLimitRn();
        // 獲取原始數(shù)據(jù)中rn字段下最大值
        Integer maxRN = (Integer) df.groupBy().max(rnColName).first().get(0);
        // 限制設(shè)置的limitRN不得大于maxRn。
        if(limitRnValue>maxRN){
            throw new ValueException(String.format( "the value of limitRn %d is larger than max value of rnCol %d, choose a smaller limitRn instead",
                    limitRnValue,maxRN));
        }
//        定義一個(gè)備用的Dataset df_c
        Dataset<Row> df_c = df.select(idColName,rnColName);
        df_c = df_c.withColumnRenamed(idColName,idColName+"_c")
                .withColumnRenamed(rnColName, rnColName+"_c");
//        將df與df_c進(jìn)行連接,連接條件df.idCol==df_c.idCol && df.rnCol<=df_c.rnCol
        Column joinExpr = df.col(idColName).equalTo(df_c.col(idColName+"_c")).and(df.col(rnColName).leq(df_c.col(rnColName+"_c")));
        Dataset<Row> joinedDf = df_c.join(df,joinExpr,"left");
//        打上一列rnCol_p = df_c.rnCol - df.rnCol 最近購(gòu)買次序列,當(dāng)前次的值0
        String pivotRnColName = rnColName+"_p";
        joinedDf = joinedDf.withColumn(pivotRnColName,joinedDf.col(rnColName+"_c").minus(joinedDf.col(rnColName)));
//        表格透視前的準(zhǔn)備工作,定義一些map和array,用于記錄表格透視計(jì)算規(guī)則和透視后的列名
        Map<String, String> featAggMap = new HashMap<>();
        Integer featNums = featCols.length;
        String[] pivotColNames = new String[maxRN*featNums-1];
        String firstPivotColName = "0_min"+"("+featCols[0]+")";
        int n = 0;
        for(int i=0; i<maxRN; i++){
            for(String feat:featCols){
                if(i==0){
                    featAggMap.put(feat,"min");
                }
                if(n>0){
                    pivotColNames[n - 1] = String.valueOf(i) + "_min" + "(" + feat + ")";
                }
                n++;
            }
        }
//        對(duì)表格進(jìn)行透視、特征字段合并、得到outputCol
        Dataset<Row> transformed = joinedDf.groupBy(joinedDf.col(idColName+"_c"), joinedDf.col(rnColName+"_c")).pivot(pivotRnColName).agg(featAggMap);
        transformed = transformed.withColumn(outputCol, functions.array(firstPivotColName,pivotColNames));
//        將outputCol連到原df上,保證經(jīng)過transform后的df只在原數(shù)據(jù)基礎(chǔ)上新增一列
        Column joinExprT = df.col(idColName).equalTo(transformed.col(idColName+"_c")).and(df.col(rnColName).equalTo(transformed.col(rnColName+"_c")));
        df = df.join(transformed.select(idColName+"_c",rnColName+"_c",outputCol),joinExprT,"left").drop(idColName+"_c",rnColName+"_c");
        return df;
    }
    @Override
    /**
     * transformSchema中定義輸入數(shù)據(jù)類型判斷的邏輯,并返回一個(gè)與transform方法輸出的Dataset相對(duì)應(yīng)的schema
     */
    public StructType transformSchema(StructType schema) {
        HashSet<String> featColSet = new HashSet<String>(Arrays.asList(getFeatCols()));
        StructField[] fields = schema.fields();
        for(StructField field:fields){
            if(featColSet.contains(field.name())){
                if(!field.dataType().sameType(DoubleType)&&!field.dataType().sameType(IntegerType)){
                    throw new TypeConstraintException(String.format("featCol DataType need DoubleType or IntegerType, " +
                            "but column %s is a %s.",field.name(),field.dataType().typeName()));
                }
            }
        }
        StructType addedSchema = schema.add(getOutputCol(), new VectorUDT(), true);
        return addedSchema;
    }
    @Override
    public Transformer copy(ParamMap extra) {
        return this.<SeqAssembler>defaultCopy(extra);
    }
    @Override
    public String uid() {
        return uid;
    }
    /**
     * 調(diào)用DefaultParamsWriter和DefaultParamsReader實(shí)現(xiàn)write()/save(), read()/load()方法.
     */
    @Override
    public MLWriter write() {
        MLWriter defaultParamsWriter = new DefaultParamsWriter(this);
        return defaultParamsWriter;
    }
    @Override
    public void save(String path) throws IOException {
        write().saveImpl(path);
    }
    public static MLReader read() {
        MLReader defaultParamsReader = new DefaultParamsReader();
        return defaultParamsReader;
    }
    public static SeqAssembler load(String path) {
        return (SeqAssembler) read().load(path);
    }
}

單元測(cè)試代碼如下:

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.apache.spark.sql.types.DataTypes.*;
import static org.apache.spark.sql.types.DataTypes.IntegerType;
/**
 * @author wangjiahui
 * @create 2023-01-25-20:51
 */
public class TestClient {
    @Test
    public void testSeqAssembler(){
        // 配置自己的SparkSession
         SparkSession spark = LocalSparkSession.getSpark();
        // 定義一個(gè)測(cè)試用的DataSet
        List<Row> rows = new ArrayList<>();
        Row row1 = RowFactory.create("a",1, 2.1, 1, 1);
        Row row2 = RowFactory.create("b",1, 2.0, 3, 2);
        Row row3 = RowFactory.create("b",2, 2.3, 4, 1);
        Row row4 = RowFactory.create("c",1, 3.1, 3, 3);
        Row row5 = RowFactory.create("c",2, 1.5, 3, 7);
        Row row6 = RowFactory.create("c",3, 4.2, 4, 2);
        rows.add(row1);
        rows.add(row2);
        rows.add(row3);
        rows.add(row4);
        rows.add(row5);
        rows.add(row6);
        List<StructField> fields = new ArrayList<StructField>();
        StructField col1 = DataTypes.createStructField("user_id", StringType, true);
        StructField col2 = DataTypes.createStructField("buy_rn", IntegerType, true);
        StructField col3 = DataTypes.createStructField("feat_1", DoubleType, true);
        StructField col4 = DataTypes.createStructField("feat_2", IntegerType, true);
        StructField col5= DataTypes.createStructField("feat_3", IntegerType, true);
        fields.add(col1);
        fields.add(col2);
        fields.add(col3);
        fields.add(col4);
        fields.add(col5);
        StructType schema = DataTypes.createStructType(fields);
        Dataset dfr = spark.createDataFrame(rows,schema);
        Dataset<Row> df = dfr.toDF();
        df = df.persist();
        System.out.println("in:");
        df.show();
        df.printSchema();
        // 定義兩個(gè)seqAssembler
        String[] featCols = new String[] {"feat_1", "feat_2", "feat_3"};
        SeqAssembler seqAssembler1 = new SeqAssembler()
                .setIdCol("user_id")
                .setRnCol("buy_rn")
                .setLimitRn(3)
                .setFeatCols(featCols)
                .setOutputCol("features");
        SeqAssembler seqAssembler2 = new SeqAssembler()
                .setIdCol("user_id")
                .setRnCol("buy_rn")
                .setLimitRn(2)
                .setFeatCols(featCols)
                .setOutputCol("features");
//        Dataset<Row> transformed = seqAssembler.transform(df);
        // 定義pipeline
        List<PipelineStage> pipelineStages = new ArrayList<>();
        pipelineStages.add(seqAssembler1);
        pipelineStages.add(seqAssembler2);
        Pipeline pipeline = new Pipeline();
        pipeline.setStages(pipelineStages.toArray(new PipelineStage[pipelineStages.size()]));
        // 寫入
        try {
            pipeline.write().overwrite().save("oss://<自己的路徑>");
        } catch (IOException e) {
            e.printStackTrace();
        }
        // 讀取
        Pipeline loadedPipeline = Pipeline.load("oss://<自己的路徑>");
        Transformer seqAssemblerLoad = (Transformer) loadedPipeline.getStages()[0];
        // 使用
        Dataset<Row> transformed = seqAssemblerLoad.transform(df);
        System.out.println("out: ");
        transformed.show(false);
        transformed.printSchema();
        spark.close();
    }
}

3. Pipeline的存儲(chǔ)文件

在oss/hdfs上找到上面單元測(cè)試中pipeline的存儲(chǔ)路徑,并將存儲(chǔ)文件夾下載到本地,pipeline存儲(chǔ)文件夾中包含metadata, stages兩個(gè)目錄,metadata中存放的是pipeline的信息,包括pipeline的uid、對(duì)應(yīng)stage的uid等。pipeline metadata文件如下:

stages目錄中存放的是我們定義的兩個(gè)SeqAssembler的metadata,SeqAssembler的metadata中的文件內(nèi)容與pipeline的metadata中的文件內(nèi)容類似,記錄了SeqAssembler相關(guān)信息與Param數(shù)據(jù):

小結(jié)

在這篇文章中我們介紹了使用java開發(fā)spark如何自定義PipelineStage,并用一個(gè)SeqAssembler的例子對(duì)自定義PipelineStage中的一些注意事項(xiàng)進(jìn)行了說明。相信這篇文章對(duì)不少java的spark開發(fā)者有一定的幫助。

以上就是Java開發(fā)Spark應(yīng)用程序自定義PipeLineStage詳解的詳細(xì)內(nèi)容,更多關(guān)于Java Spark自定義PipeLineStage的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • Java連接SAP RFC實(shí)現(xiàn)數(shù)據(jù)抽取的示例詳解

    Java連接SAP RFC實(shí)現(xiàn)數(shù)據(jù)抽取的示例詳解

    這篇文章主要為大家學(xué)習(xí)介紹了Java如何連接SAP RFC實(shí)現(xiàn)數(shù)據(jù)抽取的功能,文中的示例代碼講解詳細(xì),具有一定的參考價(jià)值,需要的可以了解下
    2023-08-08
  • Java中接口的多態(tài)詳解

    Java中接口的多態(tài)詳解

    大家好,本篇文章主要講的是Java中接口的多態(tài)詳解,感興趣的同學(xué)趕快來看一看吧,對(duì)你有幫助的話記得收藏一下
    2022-02-02
  • Java的main方法使用及說明

    Java的main方法使用及說明

    這篇文章主要介紹了Java的main方法使用及說明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2024-08-08
  • SpringBoot整合Lombok插件與使用詳解

    SpringBoot整合Lombok插件與使用詳解

    Lombok是Java開發(fā)的插件,通過注解自動(dòng)生成常用代碼,如getter/setter,節(jié)省開發(fā)時(shí)間,提高效率,它在編譯期生成方法,不影響性能,安裝Lombok需要添加Maven依賴和IDEA插件,使用注解如@Data、@Getter等簡(jiǎn)化代碼編寫,官網(wǎng)提供詳細(xì)文檔
    2024-09-09
  • java組裝樹形結(jié)構(gòu)List問題

    java組裝樹形結(jié)構(gòu)List問題

    這篇文章主要介紹了java組裝樹形結(jié)構(gòu)List問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-08-08
  • 一文秒懂java到底是值傳遞還是引用傳遞

    一文秒懂java到底是值傳遞還是引用傳遞

    這篇文章主要介紹了java到底是值傳遞還是引用傳遞的相關(guān)知識(shí),本文通過幾個(gè)例子給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-06-06
  • 使用Spring?Boot?2.x構(gòu)建Web服務(wù)的詳細(xì)代碼

    使用Spring?Boot?2.x構(gòu)建Web服務(wù)的詳細(xì)代碼

    這篇文章主要介紹了使用Spring?Boot?2.x構(gòu)建Web服務(wù)的詳細(xì)代碼,主要基于JWT的身份認(rèn)證,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-03-03
  • Java基礎(chǔ)之常用的命令行指令

    Java基礎(chǔ)之常用的命令行指令

    這篇文章主要介紹了Java基礎(chǔ)之常用的命令行指令,文中有非常詳細(xì)的圖文示例,對(duì)正在學(xué)習(xí)java基礎(chǔ)的小伙伴們有非常好的幫助,需要的朋友可以參考下
    2021-04-04
  • Maven中兩個(gè)命令clean 和 install的使用

    Maven中兩個(gè)命令clean 和 install的使用

    Maven是一個(gè)項(xiàng)目管理和自動(dòng)構(gòu)建工具,clean命令用于刪除項(xiàng)目中由先前構(gòu)建生成的target目錄,install命令用于將打包好的jar包安裝到本地倉(cāng)庫中,供其他項(xiàng)目依賴使用,下面就來詳細(xì)的介紹一下這兩個(gè)命令
    2024-09-09
  • Java代碼實(shí)現(xiàn)隨機(jī)生成漢字的方法

    Java代碼實(shí)現(xiàn)隨機(jī)生成漢字的方法

    今天小編就為大家分享一篇關(guān)于Java代碼實(shí)現(xiàn)隨機(jī)生成漢字的方法,小編覺得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧
    2019-03-03

最新評(píng)論