Skip to content

Commit

Permalink
[HUDI-4217] Improve repeated initialization of objects in ExpressionPayload (apache#5825)
Browse files Browse the repository at this point in the history
  • Loading branch information
KnightChess authored Jun 15, 2022
1 parent c291b05 commit 2bf0a19
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,59 @@

package org.apache.spark.sql.hudi.command.payload

import com.google.common.cache.CacheBuilder
import org.apache.avro.Schema
import org.apache.avro.generic.IndexedRecord
import org.apache.hudi.{AvroConversionUtils, SparkAdapterSupport}
import org.apache.hudi.HoodieSparkUtils.sparkAdapter
import org.apache.hudi.AvroConversionUtils
import org.apache.spark.sql.avro.HoodieAvroDeserializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.hudi.command.payload.SqlTypedRecord.{getAvroDeserializer, getSqlType}
import org.apache.spark.sql.types.StructType

import java.util.concurrent.Callable

/**
* A sql typed record which will convert the avro field to sql typed value.
*/
class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord with SparkAdapterSupport {
class SqlTypedRecord(val record: IndexedRecord) extends IndexedRecord {

private lazy val sqlType = AvroConversionUtils.convertAvroSchemaToStructType(getSchema)
private lazy val avroDeserializer = sparkAdapter.createAvroDeserializer(record.getSchema, sqlType)
private lazy val sqlRow = avroDeserializer.deserialize(record).get.asInstanceOf[InternalRow]
private lazy val sqlRow = getAvroDeserializer(getSchema).deserialize(record).get.asInstanceOf[InternalRow]

override def put(i: Int, v: Any): Unit = {
record.put(i, v)
}

override def get(i: Int): AnyRef = {
sqlRow.get(i, sqlType(i).dataType)
sqlRow.get(i, getSqlType(getSchema)(i).dataType)
}

override def getSchema: Schema = record.getSchema
}

object SqlTypedRecord {

  // Per-schema caches so schema conversion and deserializer construction
  // happen once per distinct avro schema instead of once per record.
  private val sqlTypeCache = CacheBuilder.newBuilder().build[Schema, StructType]()

  private val avroDeserializerCache = CacheBuilder.newBuilder().build[Schema, HoodieAvroDeserializer]()

  /**
   * Returns the catalyst [[StructType]] for the given avro schema,
   * computing and caching it on first use.
   */
  def getSqlType(schema: Schema): StructType = {
    // Cache.get(key, loader) stores the loaded value itself; an explicit
    // put inside the loader would be redundant.
    sqlTypeCache.get(schema, new Callable[StructType] {
      override def call(): StructType =
        AvroConversionUtils.convertAvroSchemaToStructType(schema)
    })
  }

  /**
   * Returns a [[HoodieAvroDeserializer]] for the given avro schema,
   * computing and caching it on first use.
   */
  def getAvroDeserializer(schema: Schema): HoodieAvroDeserializer = {
    avroDeserializerCache.get(schema, new Callable[HoodieAvroDeserializer] {
      override def call(): HoodieAvroDeserializer =
        sparkAdapter.createAvroDeserializer(schema, getSqlType(schema))
    })
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.hudi.sql.IExpressionEvaluator
import org.apache.spark.sql.avro.{AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.hudi.SerDeUtils
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.getEvaluator
import org.apache.spark.sql.hudi.command.payload.ExpressionPayload.{getEvaluator, getMergedSchema}
import org.apache.spark.sql.types.{StructField, StructType}

import java.util.concurrent.Callable
Expand Down Expand Up @@ -228,9 +228,7 @@ class ExpressionPayload(record: GenericRecord,
*/
private def joinRecord(sourceRecord: IndexedRecord, targetRecord: IndexedRecord): IndexedRecord = {
val leftSchema = sourceRecord.getSchema
// the targetRecord is load from the disk, it contains the meta fields, so we remove it here
val rightSchema = HoodieAvroUtils.removeMetadataFields(targetRecord.getSchema)
val joinSchema = mergeSchema(leftSchema, rightSchema)
val joinSchema = getMergedSchema(leftSchema, targetRecord.getSchema)

val values = new ArrayBuffer[AnyRef]()
for (i <- 0 until joinSchema.getFields.size()) {
Expand All @@ -244,17 +242,6 @@ class ExpressionPayload(record: GenericRecord,
convertToRecord(values.toArray, joinSchema)
}

/**
 * Merges two avro record schemas into one flat record schema, prefixing
 * fields of `a` with "a_" and fields of `b` with "b_" so the two sides
 * cannot collide on field names.
 */
private def mergeSchema(a: Schema, b: Schema): Schema = {
  def prefixedFields(prefix: String, schema: Schema): Seq[Schema.Field] =
    schema.getFields.asScala.map { field =>
      new Schema.Field(prefix + field.name, field.schema, field.doc, field.defaultVal, field.order)
    }

  val joinedFields = prefixedFields("a_", a) ++ prefixedFields("b_", b)
  Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, joinedFields.asJava)
}

private def evaluate(evaluator: IExpressionEvaluator, sqlTypedRecord: SqlTypedRecord): GenericRecord = {
try evaluator.eval(sqlTypedRecord) catch {
case e: Throwable =>
Expand Down Expand Up @@ -318,5 +305,30 @@ object ExpressionPayload {
}
})
}

// Caches the merged (source, target) schema so that metadata-field stripping
// and schema merging run once per distinct schema pair instead of per record.
private val mergedSchemaCache = CacheBuilder.newBuilder().build[TupleSchema, Schema]()

/**
 * Returns the merged schema of `source` and `target`, with the hoodie
 * metadata fields removed from `target` first (the target record is loaded
 * from disk and therefore carries them). The result is computed once per
 * (source, target) schema pair and cached.
 */
def getMergedSchema(source: Schema, target: Schema): Schema = {
  val cacheKey = TupleSchema(source, target)
  mergedSchemaCache.get(cacheKey, new Callable[Schema] {
    override def call(): Schema =
      mergeSchema(source, HoodieAvroUtils.removeMetadataFields(target))
  })
}

/**
 * Builds a joined avro record schema from `a` and `b`: every field of `a`
 * is renamed with an "a_" prefix and every field of `b` with a "b_" prefix,
 * guaranteeing unique names in the combined record.
 */
def mergeSchema(a: Schema, b: Schema): Schema = {
  val leftFields = a.getFields.asScala.map { f =>
    new Schema.Field("a_" + f.name, f.schema, f.doc, f.defaultVal, f.order)
  }
  val rightFields = b.getFields.asScala.map { f =>
    new Schema.Field("b_" + f.name, f.schema, f.doc, f.defaultVal, f.order)
  }
  Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, (leftFields ++ rightFields).asJava)
}

// Cache key pairing a source and a target avro schema for the merged-schema cache.
case class TupleSchema(first: Schema, second: Schema)
}

0 comments on commit 2bf0a19

Please sign in to comment.