[SPARK-44343][CONNECT] Prepare ScalaReflection to the move to SQL/API
### What changes were proposed in this pull request?
This PR moves all catalyst-specific internals out of ScalaReflection into other catalyst classes (see the sketch below):
- Serializer expression generation is moved to `SerializerBuildHelper`.
- Deserializer expression generation is moved to `DeserializerBuildHelper`.
- Common utils are moved to `EncoderUtils`.
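
As a minimal sketch of the resulting layout (assuming call sites inside `org.apache.spark.sql.catalyst`; `createDeserializer` and `externalDataTypeFor` appear in the diff below, while the `SerializerBuildHelper` entry point is an assumed counterpart):

```scala
import org.apache.spark.sql.catalyst.DeserializerBuildHelper
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.encoders.EncoderUtils
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.DataType

// Deserializer expression generation now lives in DeserializerBuildHelper
// rather than ScalaReflection (entry point shown in the diff below).
val deserializer: Expression = DeserializerBuildHelper.createDeserializer(StringEncoder)

// Common helpers such as externalDataTypeFor now live in EncoderUtils.
val externalType: DataType = EncoderUtils.externalDataTypeFor(StringEncoder)

// The serializer side is assumed to have a matching entry point in
// SerializerBuildHelper (not shown in this file's diff):
// val serializer = SerializerBuildHelper.createSerializer(StringEncoder)
```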

### Why are the changes needed?
We want to use ScalaReflection-based encoder inference for both SQL/Core and Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes apache#41920 from hvanhovell/SPARK-44343.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
hvanhovell committed Jul 10, 2023 · commit f6866d1 (1 parent: f1ec99b)
Showing 13 changed files with 630 additions and 617 deletions.
DeserializerBuildHelper.scala
@@ -17,10 +17,14 @@

package org.apache.spark.sql.catalyst

-import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
-import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.{expressions => exprs}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, MapKeys, MapValues, UpCast}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._

object DeserializerBuildHelper {
@@ -193,4 +197,246 @@ object DeserializerBuildHelper {
UpCast(expr, DecimalType, walkedTypePath.getPaths)
case _ => UpCast(expr, expected, walkedTypePath.getPaths)
}

/**
* Returns an expression for deserializing the Spark SQL representation of an object into its
* external form. The mapping between the internal and external representations is
* described by encoder `enc`. The Spark SQL representation is located at ordinal 0 of
* a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
* `UnresolvedExtractValue`.
*
* The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
* deserializer expression when using it.
*
* @param enc encoder that describes the mapping between the Spark SQL representation and the
* external representation.
*/
def createDeserializer[T](enc: AgnosticEncoder[T]): Expression = {
val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
// Assumes we are deserializing the first column of a row.
val input = GetColumnByOrdinal(0, enc.dataType)
enc match {
case AgnosticEncoders.RowEncoder(fields) =>
val children = fields.zipWithIndex.map { case (f, i) =>
createDeserializer(f.enc, GetStructField(input, i), walkedTypePath)
}
CreateExternalRow(children, enc.schema)
case _ =>
val deserializer = createDeserializer(
enc,
upCastToExpectedType(input, enc.dataType, walkedTypePath),
walkedTypePath)
expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
}
}
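// Illustrative sketch: for StringEncoder the deserializer built above is roughly
//   Invoke(UpCast(GetColumnByOrdinal(0, StringType), ...), "toString", ...)
// i.e. column 0 is read, upcast to the expected type, converted to
// java.lang.String, and wrapped with null-safety handling.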

/**
* Returns an expression for deserializing the value of an input expression into its external
* representation. The mapping between the internal and external representations is
* described by encoder `enc`.
*
* @param enc encoder that describes the mapping between the Spark SQL representation and the
* external representation.
* @param path The expression that can be used to extract the serialized value.
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
*/
private def createDeserializer(
enc: AgnosticEncoder[_],
path: Expression,
walkedTypePath: WalkedTypePath): Expression = enc match {
case _ if isNativeEncoder(enc) =>
path
case _: BoxedLeafEncoder[_, _] =>
createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
case JavaEnumEncoder(tag) =>
val toString = createDeserializerForString(path, returnNullable = false)
createDeserializerForTypesSupportValueOf(toString, tag.runtimeClass)
case ScalaEnumEncoder(parent, tag) =>
StaticInvoke(
parent,
ObjectType(tag.runtimeClass),
"withName",
createDeserializerForString(path, returnNullable = false) :: Nil,
returnNullable = false)
case StringEncoder =>
createDeserializerForString(path, returnNullable = false)
case _: ScalaDecimalEncoder =>
createDeserializerForScalaBigDecimal(path, returnNullable = false)
case _: JavaDecimalEncoder =>
createDeserializerForJavaBigDecimal(path, returnNullable = false)
case ScalaBigIntEncoder =>
createDeserializerForScalaBigInt(path)
case JavaBigIntEncoder =>
createDeserializerForJavaBigInteger(path, returnNullable = false)
case DayTimeIntervalEncoder =>
createDeserializerForDuration(path)
case YearMonthIntervalEncoder =>
createDeserializerForPeriod(path)
case _: DateEncoder =>
createDeserializerForSqlDate(path)
case _: LocalDateEncoder =>
createDeserializerForLocalDate(path)
case _: TimestampEncoder =>
createDeserializerForSqlTimestamp(path)
case _: InstantEncoder =>
createDeserializerForInstant(path)
case LocalDateTimeEncoder =>
createDeserializerForLocalDateTime(path)
case UDTEncoder(udt, udtClass) =>
val obj = NewInstance(udtClass, Nil, ObjectType(udtClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
case OptionEncoder(valueEnc) =>
val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName)
val deserializer = createDeserializer(valueEnc, path, newTypePath)
WrapOption(deserializer, externalDataTypeFor(valueEnc))

case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) =>
Invoke(
deserializeArray(
path,
elementEnc,
containsNull,
None,
walkedTypePath),
toArrayMethodName(elementEnc),
ObjectType(enc.clsTag.runtimeClass),
returnNullable = false)

case IterableEncoder(clsTag, elementEnc, containsNull, _) =>
deserializeArray(
path,
elementEnc,
containsNull,
Option(clsTag.runtimeClass),
walkedTypePath)

case MapEncoder(tag, keyEncoder, valueEncoder, _)
if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) =>
// TODO (hvanhovell): this can be improved.
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
valueEncoder.clsTag.runtimeClass.getName)

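// Deserialize the keys and values into Object arrays, then zip them back
// into a java.util.Map via ArrayBasedMapData.toJavaMap.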
val keyData =
Invoke(
UnresolvedMapObjects(
p => createDeserializer(keyEncoder, p, newTypePath),
MapKeys(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => createDeserializer(valueEncoder, p, newTypePath),
MapValues(path)),
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
ArrayBasedMapData.getClass,
ObjectType(classOf[java.util.Map[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil,
returnNullable = false)

case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
valueEncoder.clsTag.runtimeClass.getName)
UnresolvedCatalystToExternalMap(
path,
createDeserializer(keyEncoder, _, newTypePath),
createDeserializer(valueEncoder, _, newTypePath),
tag.runtimeClass)

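// Illustrative: for `case class Point(x: Int, y: Int)` the branch below yields
// roughly If(IsNull(path), Literal(null, ObjectType(Point)),
// NewInstance(Point, <deserialized x, y>, ...)), reading fields by name
// (or by ordinal for tuples).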
case ProductEncoder(tag, fields) =>
val cls = tag.runtimeClass
val dt = ObjectType(cls)
val isTuple = cls.getName.startsWith("scala.Tuple")
val arguments = fields.zipWithIndex.map {
case (field, i) =>
val newTypePath = walkedTypePath.recordField(
field.enc.clsTag.runtimeClass.getName,
field.name)
// For tuples, we grab the inner fields by ordinal instead of name.
val getter = if (isTuple) {
addToPathOrdinal(path, i, field.enc.dataType, newTypePath)
} else {
addToPath(path, field.name, field.enc.dataType, newTypePath)
}
expressionWithNullSafety(
createDeserializer(field.enc, getter, newTypePath),
field.enc.nullable,
newTypePath)
}
exprs.If(
IsNull(path),
exprs.Literal.create(null, dt),
NewInstance(cls, arguments, dt, propagateNull = false))

case AgnosticEncoders.RowEncoder(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val newTypePath = walkedTypePath.recordField(
f.enc.clsTag.runtimeClass.getName,
f.name)
exprs.If(
Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
exprs.Literal.create(null, externalDataTypeFor(f.enc)),
createDeserializer(f.enc, GetStructField(path, i), newTypePath))
}
exprs.If(IsNull(path),
exprs.Literal.create(null, externalDataTypeFor(enc)),
CreateExternalRow(convertedFields, enc.schema))

case JavaBeanEncoder(tag, fields) =>
val setters = fields.map { f =>
val newTypePath = walkedTypePath.recordField(
f.enc.clsTag.runtimeClass.getName,
f.name)
val setter = expressionWithNullSafety(
createDeserializer(
f.enc,
addToPath(path, f.name, f.enc.dataType, newTypePath),
newTypePath),
nullable = f.nullable,
newTypePath)
f.writeMethod.get -> setter
}

val cls = tag.runtimeClass
val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
}

private def deserializeArray(
path: Expression,
elementEnc: AgnosticEncoder[_],
containsNull: Boolean,
cls: Option[Class[_]],
walkedTypePath: WalkedTypePath): Expression = {
val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expects.
deserializerForWithNullSafetyAndUpcast(
element,
elementEnc.dataType,
nullable = containsNull,
newTypePath,
createDeserializer(elementEnc, _, newTypePath))
}
UnresolvedMapObjects(mapFunction, path, cls)
}

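/**
 * Picks the method used to convert the intermediate array produced by
 * `deserializeArray` into an array of the element type: a specialized
 * primitive conversion where available, otherwise the generic `array` method.
 */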
private def toArrayMethodName(enc: AgnosticEncoder[_]): String = enc match {
case PrimitiveBooleanEncoder => "toBooleanArray"
case PrimitiveByteEncoder => "toByteArray"
case PrimitiveShortEncoder => "toShortArray"
case PrimitiveIntEncoder => "toIntArray"
case PrimitiveLongEncoder => "toLongArray"
case PrimitiveFloatEncoder => "toFloatArray"
case PrimitiveDoubleEncoder => "toDoubleArray"
case _ => "array"
}
}