美文网首页
Spark开发--Spark SQL--自定义函数(十五)

Spark开发--Spark SQL--自定义函数(十五)

作者: 无剑_君 | 来源:发表于2020-04-14 16:23 被阅读0次

函数:http://spark.apache.org/docs/latest/api/sql/index.html

一、自定义函数简介

在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等

二、自定义UDF函数

自定义一个UDF函数需要继承UserDefinedAggregateFunction类,并实现其中的8个方法。
通过spark.udf.register("funcName", func) 来进行注册
使用:select funcName(name) from people 来直接使用

1. 匿名函数注册UDF

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object MySpark {

  def main(args: Array[String]) {
    // 定义应用名称
    val conf = new SparkConf().setAppName("mySpark0")
    conf.setMaster("spark://master:7077")
    conf.setJars(Seq("/root/SparkTest.jar"))

    // 创建SparkSession对象
    val spark = SparkSession.builder()
      .appName("DataFrameAPP")
      .config(conf)
      .getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    userDF.show
    // 注册一张user表
    userDF.createOrReplaceTempView("user")
    // 匿名方法注册函数
    spark.udf.register("strLen", (str: String) => str.length())
    // 函数使用
    spark.sql("select name,strLen(name) as name_len from user").show

  }
}

// 执行结果
+-----+--------+
| name|name_len|
+-----+--------+
|  Leo|       3|
|Marry|       5|
| Jack|       4|
|  Tom|       3|
+-----+--------+

2. 实名函数注册UDF

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object MySpark {

  def main(args: Array[String]) {
    // 定义应用名称
    val conf = new SparkConf().setAppName("mySpark0")
    conf.setMaster("spark://master:7077")
    conf.setJars(Seq("/root/SparkTest.jar"))

    // 创建SparkSession对象
    val spark = SparkSession.builder()
      .appName("DataFrameAPP")
      .config(conf)
      .getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    userDF.show
    // 注册一张user表
    userDF.createOrReplaceTempView("user")
    // 实名函数注册
    spark.udf.register("isAdult", isAdult _)
    // 使用
    spark.sql("select name,isAdult(age) as age from user").show
  }
  /**
    * 自定义函数
    * 根据年龄大小返回是否成年 成年:true,未成年:false
    */
  def isAdult(age: Int) = {
    if (age < 18) {
      false
    } else {
      true
    }
  }
}

// 执行结果
+-----+-----+
| name|  age|
+-----+-----+
|  Leo|false|
|Marry| true|
| Jack|false|
|  Tom| true|
+-----+-----+

3. DataFrame用法

DataFrame的udf方法虽然和Spark Sql的名字一样,但是属于不同的类,它在org.apache.spark.sql.functions里。

import org.apache.spark.sql.functions._
//注册自定义函数(通过匿名函数)
val strLen = udf((str: String) => str.length())
//注册自定义函数(通过实名函数)
val udf_isAdult = udf(isAdult _)

示例:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object MySpark {

  def main(args: Array[String]) {
    // 定义应用名称
    val conf = new SparkConf().setAppName("mySpark0")
    conf.setMaster("spark://master:7077")
    conf.setJars(Seq("/root/SparkTest.jar"))

    // 创建SparkSession对象
    val spark = SparkSession.builder()
      .appName("DataFrameAPP")
      .config(conf)
      .getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    userDF.show
    // 注册一张user表
    userDF.createOrReplaceTempView("user")
    import org.apache.spark.sql.functions._
    // 注册自定义函数(通过匿名函数)
    val strLen = udf((str: String) => str.length())
    //注册自定义函数(通过实名函数)
    val udf_isAdult = udf(isAdult _)
    // 使用,通过withColumn添加列
//    userDF.withColumn("name_len", strLen(col("name")))
//      .withColumn("isAdult", udf_isAdult(col("age"))).show
    // 通过select添加列
    userDF.select(col("*"), strLen(col("name")) as "name_len",
      udf_isAdult(col("age")) as "isAdult").show
  }
  /**
    * 自定义函数
    * 根据年龄大小返回是否成年 成年:true,未成年:false
    */
  def isAdult(age: Int) = {
    if (age < 18) {
      false
    } else {
      true
    }
  }
}

执行结果:

+-----+---+--------+-------+
| name|age|name_len|isAdult|
+-----+---+--------+-------+
|  Leo| 16|       3|  false|
|Marry| 21|       5|   true|
| Jack| 14|       4|  false|
|  Tom| 18|       3|   true|
+-----+---+--------+-------+

三、自定义UDAF函数

UDAF:用户自定义的聚合函数,函数本身作用于数据集合,能够在具体操作的基础上进行自定义操作。
示例1:
update:各个分组的值内部聚合
merge:各个节点的同一分组的值聚合
evaluate:聚合各个分组的缓存值
自定义函数:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  *  单词数量统计
  */
object  StringCount extends UserDefinedAggregateFunction {
  /**
    * inputSchema,指的是,输入数据的类型
    *
    * @return
    */
  override def inputSchema: StructType = {
    StructType(Array(StructField("str", StringType, true)))
  }

  /**
    * bufferSchema,指的是,中间进行聚合时,所处理的数据的类型
    *
    * @return
    */
  override def bufferSchema: StructType = {
    // 计数为整型
    StructType(Array(StructField("count", IntegerType, true)))
  }

  /**
    * dataType,指的是,函数返回值的类型
    *
    * @return
    */
  override def dataType: DataType = {
    IntegerType
  }

  override def deterministic: Boolean = {
    true
  }

  /**
    * 为每个分组的数据执行初始化操作
    *
    * @param buffer
    */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  /**
    * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
    * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
    * 聚和发生在reduce端.
    * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
    * update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // 进行累加计数
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  /**
    * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
    * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
    * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
    * 也可以是一个节点里面的多个executor合并 reduce端大聚合
    * merge后的结果写如buffer1中
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 合并操作
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  /**
    * 最后,指的是,一个分组的聚合值,如何通过中间的缓存聚合值,最后返回一个最终的聚合值
    * @param buffer
    * @return
    */
  override def evaluate(buffer: Row): Any = {
    // 返回结果
    buffer.getAs[Int](0)
  }
}

注册,测试:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object MySpark {

  def main(args: Array[String]) {
    // 定义应用名称
    val conf = new SparkConf().setAppName("mySpark0")
    conf.setMaster("spark://master:7077")
    conf.setJars(Seq("/root/SparkTest.jar"))

    // 创建SparkSession对象
    val spark = SparkSession.builder()
      .appName("DataFrameAPP")
      .config(conf)
      .getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    //    userDF.show
    // 注册一张user表
    userDF.createOrReplaceTempView("user")
    // 注册函数:SQLContext.udf.register(), 添加 StringCount对象
    spark.udf.register("strCount", StringCount)
    // 使用自定义函数,单词计数
    spark.sql("select name,strCount(name) from user group by name").show
  }
}

因为新增加函数类,如果报错:
Caused by: java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.sql.execution.aggregate.HashAggregateExec.aggregateExpressions of type scala.collection.Seq
请重新打包,拷贝jar文件到D:\root:

原jar
目标

示例2:
自定义函数:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  *  求平均数
  */
class CustomerAvg extends UserDefinedAggregateFunction {
  //输入的类型
  override def inputSchema: StructType = StructType(StructField("salary", LongType) :: Nil)
  //缓存数据的类型
  override def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  //返回值类型
  override def dataType: DataType = LongType
  //幂等性
  override def deterministic: Boolean = true
  //初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  //更新 分区内操作
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0)=buffer.getLong(0) +input.getLong(0)
    buffer(1)=buffer.getLong(1)+1L
  }
  //合并 分区与分区之间操作
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0)=buffer1.getLong(0)+buffer2.getLong(0)
    buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
  }
  //最终执行的方法
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)/buffer.getLong(1)
  }
}

测试:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object MySpark {

  def main(args: Array[String]) {
    // 定义应用名称
    val conf = new SparkConf().setAppName("mySpark0")
    conf.setMaster("spark://master:7077")
    conf.setJars(Seq("/root/SparkTest.jar"))

    // 创建SparkSession对象
    val spark = SparkSession.builder()
      .appName("DataFrameAPP")
      .config(conf)
      .getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    //  userDF.show
    // 注册一张user表
    userDF.createOrReplaceTempView("user")
    spark.udf.register("MyAvg",new CustomerAvg)
    // 测试
    spark.sql("select MyAvg(age) avg_age from user").show()
  }
}

+-------+
|avg_age|
+-------+
|     17|
+-------+

相关文章

网友评论

      本文标题:Spark开发--Spark SQL--自定义函数(十五)

      本文链接:https://www.haomeiwen.com/subject/letcoctx.html