美文网首页
用隐式类增强DataFrameWriter实现spark对mys

用隐式类增强DataFrameWriter实现spark对mys

作者: zhujh | 来源:发表于2020-06-09 09:39 被阅读0次

实际应用中经常会遇到spark把DataFrame保存到mysql,同时遇重更新无重插入的场景,spark原生save只实现了insert,在遇到唯一性约束时就会抛出异常。为了解决这种问题,我曾用过两种方式,一种是采用foreachPartition,在每个partition里建立connection然后插入数据,另一种方式是在mysql中建立临时表和触发器,spark将DataFrame的数据SaveMode.Append到临时表,临时表的触发器对正式表进行更新。两种方法中,前者需要污染大量代码,后者则把所有压力集中到mysql中,而且因为mysql没有postgresql的return null机制,需要定期清除临时表,可能会引起事务卡死。
scala有很多类似于java和python的语法风格,但隐式implicit是scala独有的特性,尤其是隐式类implicit class,能起到类似于javascript的prototype的作用,能对各种类进行增强,在不修改源代码重新编译的情况下,给类增加方法。
通过源码分析,DataFrameWriter的save方法是通过Datasource的planForWriting更新logicalPlan,在Datasource中根据className所对应类(jdbc对应的是JdbcRelationProvider类)的createRelation方法写入数据库。因此,需要改动的地方并不多,只需要增加一个类似于JdbcRelationProvider的类,其实只要继承并修改createRelation方法即可,并在Datasource更新logicalPlan的时候把className指定成这个类即可。

class MysqlUpdateRelationProvider extends JdbcRelationProvider {
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], df: DataFrame): BaseRelation = {
    val options = new JdbcOptionsInWrite(parameters)
    val isCaseSensitive = sqlContext.sparkSession.sessionState.conf.caseSensitiveAnalysis
    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, options)
      if (tableExists) {
        mode match {
          case SaveMode.Overwrite =>
            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, options)
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              JdbcUtilsEnhance.updateTable(df, tableSchema, isCaseSensitive, options)
            } else {
              // Otherwise, do not truncate the table, instead drop and recreate it
              dropTable(conn, options.table, options)
              createTable(conn, df, options)
              JdbcUtilsEnhance.updateTable(df, Some(df.schema), isCaseSensitive, options)
            }

          case SaveMode.Append =>
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            JdbcUtilsEnhance.updateTable(df, tableSchema, isCaseSensitive, options)

          case SaveMode.ErrorIfExists =>
            throw new Exception(
              s"Table or view '${options.table}' already exists. " +
                s"SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
          // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
          // to not save the contents of the DataFrame and to not change the existing data.
          // Therefore, it is okay to do nothing here and then just return the relation below.
        }
      } else {
        createTable(conn, df, options)
        JdbcUtilsEnhance.updateTable(df, Some(df.schema), isCaseSensitive, options)
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }

上述语句几乎完全复制黏贴自父类,只是在JdbcUtilsEnhance.updateTable的地方,原来都是saveTable。

object JdbcUtilsEnhance {
  def updateTable(df: DataFrame,
                  tableSchema: Option[StructType],
                  isCaseSensitive: Boolean,
                  options: JdbcOptionsInWrite): Unit = {
    val url = options.url
    val table = options.table
    val dialect = JdbcDialects.get(url)
    println(dialect)
    val rddSchema = df.schema
    val getConnection: () => Connection = createConnectionFactory(options)
    val batchSize = options.batchSize
    println(batchSize)
    val isolationLevel = options.isolationLevel

    val updateStmt = getUpdateStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
    println(updateStmt)
    val repartitionedDF = options.numPartitions match {
      case Some(n) if n <= 0 => throw new IllegalArgumentException(
        s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
          "via JDBC. The minimum value is 1.")
      case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
      case _ => df
    }
    repartitionedDF.rdd.foreachPartition(iterator => savePartition(
      getConnection, table, iterator, rddSchema, updateStmt, batchSize, dialect, isolationLevel,
      options)
    )
  }

  def getUpdateStatement(table: String,
                         rddSchema: StructType,
                         tableSchema: Option[StructType],
                         isCaseSensitive: Boolean,
                         dialect: JdbcDialect): String = {
    val columns = if (tableSchema.isEmpty) {
      rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    } else {
      val columnNameEquality = if (isCaseSensitive) {
        org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
      } else {
        org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
      }
      // The generated insert statement needs to follow rddSchema's column sequence and
      // tableSchema's column names. When appending data into some case-sensitive DBMSs like
      // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
      // RDD column names for user convenience.
      val tableColumnNames = tableSchema.get.fieldNames
      rddSchema.fields.map { col =>
        val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
          throw new Exception(s"""Column "${col.name}" not found in schema $tableSchema""")
        }
        dialect.quoteIdentifier(normalizedName)
      }.mkString(",")
    }
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    s"""INSERT INTO $table ($columns) VALUES ($placeholders)
       |ON DUPLICATE KEY UPDATE
       |${columns.split(",").map(col=>s"$col=VALUES($col)").mkString(",")}
       |""".stripMargin
  }
}

增加的两个update相关的方法也几乎也全部复制黏贴自JdbcUtils。
然后是隐士类,由于DataFrameWriter的属性几乎都是private类型,所以需要用到反射。

object DataFrameWriterEnhance {

  implicit class DataFrameWriterMysqlUpdateEnhance(writer: DataFrameWriter[Row]) {
    def update(): Unit = {
      val extraOptionsField = writer.getClass.getDeclaredField("extraOptions")
      val dfField = writer.getClass.getDeclaredField("df")
      val sourceField = writer.getClass.getDeclaredField("source")
      val partitioningColumnsField = writer.getClass.getDeclaredField("partitioningColumns")
      extraOptionsField.setAccessible(true)
      dfField.setAccessible(true)
      sourceField.setAccessible(true)
      partitioningColumnsField.setAccessible(true)
      val extraOptions = extraOptionsField.get(writer).asInstanceOf[scala.collection.mutable.HashMap[String, String]]
      val df = dfField.get(writer).asInstanceOf[sql.DataFrame]
      val partitioningColumns = partitioningColumnsField.get(writer).asInstanceOf[Option[Seq[String]]]
      val logicalPlanField = df.getClass.getDeclaredField("logicalPlan")
      logicalPlanField.setAccessible(true)
      var logicalPlan = logicalPlanField.get(df).asInstanceOf[LogicalPlan]
      val session = df.sparkSession
      val dataSource = DataSource(
        sparkSession = session,
        className = "org.apache.spark.enhance.MysqlUpdateRelationProvider",
        partitionColumns = partitioningColumns.getOrElse(Nil),
        options = extraOptions.toMap)
      logicalPlan = dataSource.planForWriting(SaveMode.Append, logicalPlan)
      val qe = session.sessionState.executePlan(logicalPlan)
      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
    }
  }
}

这样在应用中就可以通过update方法,实现对mysql的upsert,下面假设x字段唯一

spark.sparkContext.parallelize(
      Seq(("x1", "测试1"), ("x2", "测试2")), 2
    ).toDF("x", "y")
    .write.format("jdbc").mode(SaveMode.Append)
      .options(Map(
        "url" -> config.database.url,
        "dbtable" -> "foo",
        "user" -> config.database.username,
        "password" -> config.database.password,
        "driver" -> config.database.driver
      )).update()

这种方式对代码污染小,泛用性大,对spark的catalyst没有任何改动,但实现了需求,同时也可以拓展到任意的关系型数据库,甚至稍加改动也可以支持redis等。

相关文章

  • 用隐式类增强DataFrameWriter实现spark对mys

    实际应用中经常会遇到spark把DataFrame保存到mysql,同时遇重更新无重插入的场景,spark原生sa...

  • Scala隐式转换

    隐式转换的作用:偷偷的对某一个类进行增强(添加某一个方法) 隐式转换通用格式1、通常需要一个增强的类,例如Rich...

  • scala隐式转换及其DSL实现(二)

    在实现具体应用前,先介绍下隐式类的概念。 隐式类介绍 Scala 2.10引入了一种叫做隐式类的新特性。隐式类指的...

  • Spark 2.x读写MySQL

    简介 从 spark 2.0 开始,我们可以使用DataFrameReader 和 DataFrameWriter...

  • 隐式转换和隐式参数

    作用:能够丰富现有类库的功能,对类的方法进行增强 隐式转换函数:以implicit关键字声明并带有单个参数的函数 ...

  • 隐式增强

    给AS01做数量的校验 1.第一反应:找出口,十个出口打断点没有一个进入断点的2.第二反应:找BADI,实例化没有...

  • Spark DataFrame使用问题记录:insertInto

    1 问题描述 最近工作中有使用到spark sql的DataFrameWriter.insertInto函数往Hi...

  • 🍎 从示例理解 Scala 隐式转换

    隐式转换和隐式参数是 Scala 的两个功能强大的工具。隐式转换可以丰富现有类的功能,隐式对象是如何被自动呼出以用...

  • Dart 笔记 9 - 类(2)

    抽象类 加 abstract 关键字,抽象方法没有实现体,不需要 abstract 关键字。 隐式接口 每个类都隐...

  • synchronized 以及java内置锁

    Java中锁大致上分为两类:一类是显示锁,一类是隐式锁;今天我们重点来分析一下java中隐式锁的实现: java中...

网友评论

      本文标题:用隐式类增强DataFrameWriter实现spark对mys

      本文链接:https://www.haomeiwen.com/subject/ofnjtktx.html