From f6866d1c1ab986545f2b3c1fb254d6ca0d56c056 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 10 Jul 2023 14:45:04 -0400 Subject: [PATCH] [SPARK-44343][CONNECT] Prepare ScalaReflection for the move to SQL/API ### What changes were proposed in this pull request? This PR moves all catalyst-specific internals out of ScalaReflection into other catalyst classes: - Serializer Expression Generation is moved to `SerializerBuildHelper`. - Deserializer Expression Generation is moved to `DeserializerBuildHelper`. - Common utils are moved to `EncoderUtils`. ### Why are the changes needed? We want to use ScalaReflection-based encoder inference for both SQL/Core and Connect. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. Closes #41920 from hvanhovell/SPARK-44343. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../catalyst/DeserializerBuildHelper.scala | 254 +++- .../spark/sql/catalyst/ScalaReflection.scala | 573 +----------------- .../sql/catalyst/SerializerBuildHelper.scala | 199 +++- .../sql/catalyst/encoders/EncoderUtils.scala | 139 +++ .../catalyst/encoders/ExpressionEncoder.scala | 6 +- .../expressions/V2ExpressionUtils.scala | 5 +- .../expressions/objects/objects.scala | 20 +- .../sql/catalyst/ScalaReflectionSuite.scala | 4 +- .../expressions/ObjectExpressionsSuite.scala | 6 +- .../datasources/parquet/ParquetIOSuite.scala | 4 +- .../parquet/ParquetSchemaSuite.scala | 18 +- .../datasources/parquet/ParquetTest.scala | 3 + .../spark/sql/internal/CatalogSuite.scala | 16 +- 13 files changed, 630 insertions(+), 617 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 41fd5bb239d00..bdf996424adad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue -import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast} -import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke} -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.{expressions => exprs} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder} +import 
org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, MapKeys, MapValues, UpCast} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.types._ object DeserializerBuildHelper { @@ -193,4 +197,246 @@ object DeserializerBuildHelper { UpCast(expr, DecimalType, walkedTypePath.getPaths) case _ => UpCast(expr, expected, walkedTypePath.getPaths) } + + /** + * Returns an expression for deserializing the Spark SQL representation of an object into its + * external form. The mapping between the internal and external representations is + * described by encoder `enc`. The Spark SQL representation is located at ordinal 0 of + * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using + * `UnresolvedExtractValue`. + * + * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this + * deserializer expression when using it. + * + * @param enc encoder that describes the mapping between the Spark SQL representation and the + * external representation. + */ + def createDeserializer[T](enc: AgnosticEncoder[T]): Expression = { + val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName) + // Assumes we are deserializing the first column of a row. + val input = GetColumnByOrdinal(0, enc.dataType) + enc match { + case AgnosticEncoders.RowEncoder(fields) => + val children = fields.zipWithIndex.map { case (f, i) => + createDeserializer(f.enc, GetStructField(input, i), walkedTypePath) + } + CreateExternalRow(children, enc.schema) + case _ => + val deserializer = createDeserializer( + enc, + upCastToExpectedType(input, enc.dataType, walkedTypePath), + walkedTypePath) + expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) + } + } + + /** + * Returns an expression for deserializing the value of an input expression into its external + * representation. The mapping between the internal and external representations is + * described by encoder `enc`. + * + * @param enc encoder that describes the mapping between the Spark SQL representation and the + * external representation. + * @param path The expression which can be used to extract serialized value. + * @param walkedTypePath The paths from top to bottom to access current field when deserializing. 
+ */ + private def createDeserializer( + enc: AgnosticEncoder[_], + path: Expression, + walkedTypePath: WalkedTypePath): Expression = enc match { + case _ if isNativeEncoder(enc) => + path + case _: BoxedLeafEncoder[_, _] => + createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) + case JavaEnumEncoder(tag) => + val toString = createDeserializerForString(path, returnNullable = false) + createDeserializerForTypesSupportValueOf(toString, tag.runtimeClass) + case ScalaEnumEncoder(parent, tag) => + StaticInvoke( + parent, + ObjectType(tag.runtimeClass), + "withName", + createDeserializerForString(path, returnNullable = false) :: Nil, + returnNullable = false) + case StringEncoder => + createDeserializerForString(path, returnNullable = false) + case _: ScalaDecimalEncoder => + createDeserializerForScalaBigDecimal(path, returnNullable = false) + case _: JavaDecimalEncoder => + createDeserializerForJavaBigDecimal(path, returnNullable = false) + case ScalaBigIntEncoder => + createDeserializerForScalaBigInt(path) + case JavaBigIntEncoder => + createDeserializerForJavaBigInteger(path, returnNullable = false) + case DayTimeIntervalEncoder => + createDeserializerForDuration(path) + case YearMonthIntervalEncoder => + createDeserializerForPeriod(path) + case _: DateEncoder => + createDeserializerForSqlDate(path) + case _: LocalDateEncoder => + createDeserializerForLocalDate(path) + case _: TimestampEncoder => + createDeserializerForSqlTimestamp(path) + case _: InstantEncoder => + createDeserializerForInstant(path) + case LocalDateTimeEncoder => + createDeserializerForLocalDateTime(path) + case UDTEncoder(udt, udtClass) => + val obj = NewInstance(udtClass, Nil, ObjectType(udtClass)) + Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) + case OptionEncoder(valueEnc) => + val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName) + val deserializer = createDeserializer(valueEnc, path, newTypePath) + WrapOption(deserializer, externalDataTypeFor(valueEnc)) + + case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) => + Invoke( + deserializeArray( + path, + elementEnc, + containsNull, + None, + walkedTypePath), + toArrayMethodName(elementEnc), + ObjectType(enc.clsTag.runtimeClass), + returnNullable = false) + + case IterableEncoder(clsTag, elementEnc, containsNull, _) => + deserializeArray( + path, + elementEnc, + containsNull, + Option(clsTag.runtimeClass), + walkedTypePath) + + case MapEncoder(tag, keyEncoder, valueEncoder, _) + if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) => + // TODO (hvanhovell) this is can be improved. 
+ val newTypePath = walkedTypePath.recordMap( + keyEncoder.clsTag.runtimeClass.getName, + valueEncoder.clsTag.runtimeClass.getName) + + val keyData = + Invoke( + UnresolvedMapObjects( + p => createDeserializer(keyEncoder, p, newTypePath), + MapKeys(path)), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + UnresolvedMapObjects( + p => createDeserializer(valueEncoder, p, newTypePath), + MapValues(path)), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData.getClass, + ObjectType(classOf[java.util.Map[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil, + returnNullable = false) + + case MapEncoder(tag, keyEncoder, valueEncoder, _) => + val newTypePath = walkedTypePath.recordMap( + keyEncoder.clsTag.runtimeClass.getName, + valueEncoder.clsTag.runtimeClass.getName) + UnresolvedCatalystToExternalMap( + path, + createDeserializer(keyEncoder, _, newTypePath), + createDeserializer(valueEncoder, _, newTypePath), + tag.runtimeClass) + + case ProductEncoder(tag, fields) => + val cls = tag.runtimeClass + val dt = ObjectType(cls) + val isTuple = cls.getName.startsWith("scala.Tuple") + val arguments = fields.zipWithIndex.map { + case (field, i) => + val newTypePath = walkedTypePath.recordField( + field.enc.clsTag.runtimeClass.getName, + field.name) + // For tuples, we grab the inner fields by ordinal instead of name. + val getter = if (isTuple) { + addToPathOrdinal(path, i, field.enc.dataType, newTypePath) + } else { + addToPath(path, field.name, field.enc.dataType, newTypePath) + } + expressionWithNullSafety( + createDeserializer(field.enc, getter, newTypePath), + field.enc.nullable, + newTypePath) + } + exprs.If( + IsNull(path), + exprs.Literal.create(null, dt), + NewInstance(cls, arguments, dt, propagateNull = false)) + + case AgnosticEncoders.RowEncoder(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + val newTypePath = walkedTypePath.recordField( + f.enc.clsTag.runtimeClass.getName, + f.name) + exprs.If( + Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil), + exprs.Literal.create(null, externalDataTypeFor(f.enc)), + createDeserializer(f.enc, GetStructField(path, i), newTypePath)) + } + exprs.If(IsNull(path), + exprs.Literal.create(null, externalDataTypeFor(enc)), + CreateExternalRow(convertedFields, enc.schema)) + + case JavaBeanEncoder(tag, fields) => + val setters = fields.map { f => + val newTypePath = walkedTypePath.recordField( + f.enc.clsTag.runtimeClass.getName, + f.name) + val setter = expressionWithNullSafety( + createDeserializer( + f.enc, + addToPath(path, f.name, f.enc.dataType, newTypePath), + newTypePath), + nullable = f.nullable, + newTypePath) + f.writeMethod.get -> setter + } + + val cls = tag.runtimeClass + val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false) + val result = InitializeJavaBean(newInstance, setters.toMap) + exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result) + } + + private def deserializeArray( + path: Expression, + elementEnc: AgnosticEncoder[_], + containsNull: Boolean, + cls: Option[Class[_]], + walkedTypePath: WalkedTypePath): Expression = { + val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName) + val mapFunction: Expression => Expression = element => { + // upcast the array element to the data type the encoder expects. 
+ deserializerForWithNullSafetyAndUpcast( + element, + elementEnc.dataType, + nullable = containsNull, + newTypePath, + createDeserializer(elementEnc, _, newTypePath)) + } + UnresolvedMapObjects(mapFunction, path, cls) + } + + private def toArrayMethodName(enc: AgnosticEncoder[_]): String = enc match { + case PrimitiveBooleanEncoder => "toBooleanArray" + case PrimitiveByteEncoder => "toByteArray" + case PrimitiveShortEncoder => "toShortArray" + case PrimitiveIntEncoder => "toIntArray" + case PrimitiveLongEncoder => "toLongArray" + case PrimitiveFloatEncoder => "toFloatArray" + case PrimitiveDoubleEncoder => "toDoubleArray" + case _ => "array" + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index b0588fc704440..9f2548c378928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import javax.lang.model.SourceVersion import scala.annotation.tailrec -import scala.language.existentials import scala.reflect.ClassTag import scala.reflect.internal.Symbols import scala.util.{Failure, Success} @@ -30,24 +29,14 @@ import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.spark.SPARK_DOC_ROOT import org.apache.spark.internal.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{expressions => exprs} -import org.apache.spark.sql.catalyst.DeserializerBuildHelper._ -import org.apache.spark.sql.catalyst.SerializerBuildHelper._ -import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - +import org.apache.spark.unsafe.types.CalendarInterval private[catalyst] object ScalaSubtypeLock - /** * A default version of ScalaReflection that uses the runtime universe. */ @@ -79,38 +68,6 @@ object ScalaReflection extends ScalaReflection { } } - /** - * Return the data type we expect to see when deserializing a value with encoder `enc`. - */ - private[catalyst] def externalDataTypeFor(enc: AgnosticEncoder[_]): DataType = { - externalDataTypeFor(enc, lenientSerialization = false) - } - - private[catalyst] def lenientExternalDataTypeFor(enc: AgnosticEncoder[_]): DataType = - externalDataTypeFor(enc, enc.lenientSerialization) - - private def externalDataTypeFor( - enc: AgnosticEncoder[_], - lenientSerialization: Boolean): DataType = { - // DataType can be native. - if (isNativeEncoder(enc)) { - enc.dataType - } else if (lenientSerialization) { - ObjectType(classOf[java.lang.Object]) - } else { - ObjectType(enc.clsTag.runtimeClass) - } - } - - /** - * Returns true if the value of this data type is same between internal and external. 
- */ - def isNativeType(dt: DataType): Boolean = dt match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => true - case _ => false - } - private def baseType(tpe: `Type`): `Type` = { tpe.dealias match { case annotatedType: AnnotatedType => annotatedType.underlying @@ -118,452 +75,6 @@ object ScalaReflection extends ScalaReflection { } } - /** - * Returns true if the encoders' internal and external data type is the same. - */ - private def isNativeEncoder(enc: AgnosticEncoder[_]): Boolean = enc match { - case PrimitiveBooleanEncoder => true - case PrimitiveByteEncoder => true - case PrimitiveShortEncoder => true - case PrimitiveIntEncoder => true - case PrimitiveLongEncoder => true - case PrimitiveFloatEncoder => true - case PrimitiveDoubleEncoder => true - case NullEncoder => true - case CalendarIntervalEncoder => true - case BinaryEncoder => true - case _: SparkDecimalEncoder => true - case _ => false - } - - private def toArrayMethodName(enc: AgnosticEncoder[_]): String = enc match { - case PrimitiveBooleanEncoder => "toBooleanArray" - case PrimitiveByteEncoder => "toByteArray" - case PrimitiveShortEncoder => "toShortArray" - case PrimitiveIntEncoder => "toIntArray" - case PrimitiveLongEncoder => "toLongArray" - case PrimitiveFloatEncoder => "toFloatArray" - case PrimitiveDoubleEncoder => "toDoubleArray" - case _ => "array" - } - - /** - * Returns an expression for deserializing the Spark SQL representation of an object into its - * external form. The mapping between the internal and external representations is - * described by encoder `enc`. The Spark SQL representation is located at ordinal 0 of - * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using - * `UnresolvedExtractValue`. - * - * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this - * deserializer expression when using it. - * - * @param enc encoder that describes the mapping between the Spark SQL representation and the - * external representation. - */ - def deserializerFor[T](enc: AgnosticEncoder[T]): Expression = { - val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName) - // Assumes we are deserializing the first column of a row. - val input = GetColumnByOrdinal(0, enc.dataType) - enc match { - case RowEncoder(fields) => - val children = fields.zipWithIndex.map { case (f, i) => - deserializerFor(f.enc, GetStructField(input, i), walkedTypePath) - } - CreateExternalRow(children, enc.schema) - case _ => - val deserializer = deserializerFor( - enc, - upCastToExpectedType(input, enc.dataType, walkedTypePath), - walkedTypePath) - expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) - } - } - - /** - * Returns an expression for deserializing the value of an input expression into its external - * representation. The mapping between the internal and external representations is - * described by encoder `enc`. - * - * @param enc encoder that describes the mapping between the Spark SQL representation and the - * external representation. - * @param path The expression which can be used to extract serialized value. - * @param walkedTypePath The paths from top to bottom to access current field when deserializing. 
- */ - private def deserializerFor( - enc: AgnosticEncoder[_], - path: Expression, - walkedTypePath: WalkedTypePath): Expression = enc match { - case _ if isNativeEncoder(enc) => - path - case _: BoxedLeafEncoder[_, _] => - createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass) - case JavaEnumEncoder(tag) => - val toString = createDeserializerForString(path, returnNullable = false) - createDeserializerForTypesSupportValueOf(toString, tag.runtimeClass) - case ScalaEnumEncoder(parent, tag) => - StaticInvoke( - parent, - ObjectType(tag.runtimeClass), - "withName", - createDeserializerForString(path, returnNullable = false) :: Nil, - returnNullable = false) - case StringEncoder => - createDeserializerForString(path, returnNullable = false) - case _: ScalaDecimalEncoder => - createDeserializerForScalaBigDecimal(path, returnNullable = false) - case _: JavaDecimalEncoder => - createDeserializerForJavaBigDecimal(path, returnNullable = false) - case ScalaBigIntEncoder => - createDeserializerForScalaBigInt(path) - case JavaBigIntEncoder => - createDeserializerForJavaBigInteger(path, returnNullable = false) - case DayTimeIntervalEncoder => - createDeserializerForDuration(path) - case YearMonthIntervalEncoder => - createDeserializerForPeriod(path) - case _: DateEncoder => - createDeserializerForSqlDate(path) - case _: LocalDateEncoder => - createDeserializerForLocalDate(path) - case _: TimestampEncoder => - createDeserializerForSqlTimestamp(path) - case _: InstantEncoder => - createDeserializerForInstant(path) - case LocalDateTimeEncoder => - createDeserializerForLocalDateTime(path) - case UDTEncoder(udt, udtClass) => - val obj = NewInstance(udtClass, Nil, ObjectType(udtClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) - case OptionEncoder(valueEnc) => - val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName) - val deserializer = deserializerFor(valueEnc, path, newTypePath) - WrapOption(deserializer, externalDataTypeFor(valueEnc)) - - case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) => - Invoke( - deserializeArray( - path, - elementEnc, - containsNull, - None, - walkedTypePath), - toArrayMethodName(elementEnc), - ObjectType(enc.clsTag.runtimeClass), - returnNullable = false) - - case IterableEncoder(clsTag, elementEnc, containsNull, _) => - deserializeArray( - path, - elementEnc, - containsNull, - Option(clsTag.runtimeClass), - walkedTypePath) - - case MapEncoder(tag, keyEncoder, valueEncoder, _) - if classOf[java.util.Map[_, _]].isAssignableFrom(tag.runtimeClass) => - // TODO (hvanhovell) this is can be improved. 
- val newTypePath = walkedTypePath.recordMap( - keyEncoder.clsTag.runtimeClass.getName, - valueEncoder.clsTag.runtimeClass.getName) - - val keyData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(keyEncoder, p, newTypePath), - MapKeys(path)), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(valueEncoder, p, newTypePath), - MapValues(path)), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[java.util.Map[_, _]]), - "toJavaMap", - keyData :: valueData :: Nil, - returnNullable = false) - - case MapEncoder(tag, keyEncoder, valueEncoder, _) => - val newTypePath = walkedTypePath.recordMap( - keyEncoder.clsTag.runtimeClass.getName, - valueEncoder.clsTag.runtimeClass.getName) - UnresolvedCatalystToExternalMap( - path, - deserializerFor(keyEncoder, _, newTypePath), - deserializerFor(valueEncoder, _, newTypePath), - tag.runtimeClass) - - case ProductEncoder(tag, fields) => - val cls = tag.runtimeClass - val dt = ObjectType(cls) - val isTuple = cls.getName.startsWith("scala.Tuple") - val arguments = fields.zipWithIndex.map { - case (field, i) => - val newTypePath = walkedTypePath.recordField( - field.enc.clsTag.runtimeClass.getName, - field.name) - // For tuples, we grab the inner fields by ordinal instead of name. - val getter = if (isTuple) { - addToPathOrdinal(path, i, field.enc.dataType, newTypePath) - } else { - addToPath(path, field.name, field.enc.dataType, newTypePath) - } - expressionWithNullSafety( - deserializerFor(field.enc, getter, newTypePath), - field.enc.nullable, - newTypePath) - } - expressions.If( - IsNull(path), - expressions.Literal.create(null, dt), - NewInstance(cls, arguments, dt, propagateNull = false)) - - case RowEncoder(fields) => - val convertedFields = fields.zipWithIndex.map { case (f, i) => - val newTypePath = walkedTypePath.recordField( - f.enc.clsTag.runtimeClass.getName, - f.name) - exprs.If( - Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil), - exprs.Literal.create(null, externalDataTypeFor(f.enc)), - deserializerFor(f.enc, GetStructField(path, i), newTypePath)) - } - exprs.If(IsNull(path), - exprs.Literal.create(null, externalDataTypeFor(enc)), - CreateExternalRow(convertedFields, enc.schema)) - - case JavaBeanEncoder(tag, fields) => - val setters = fields.map { f => - val newTypePath = walkedTypePath.recordField( - f.enc.clsTag.runtimeClass.getName, - f.name) - val setter = expressionWithNullSafety( - deserializerFor( - f.enc, - addToPath(path, f.name, f.enc.dataType, newTypePath), - newTypePath), - nullable = f.nullable, - newTypePath) - f.writeMethod.get -> setter - } - - val cls = tag.runtimeClass - val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false) - val result = InitializeJavaBean(newInstance, setters.toMap) - exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result) - } - - private def deserializeArray( - path: Expression, - elementEnc: AgnosticEncoder[_], - containsNull: Boolean, - cls: Option[Class[_]], - walkedTypePath: WalkedTypePath): Expression = { - val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expects. 
- deserializerForWithNullSafetyAndUpcast( - element, - elementEnc.dataType, - nullable = containsNull, - newTypePath, - deserializerFor(elementEnc, _, newTypePath)) - } - UnresolvedMapObjects(mapFunction, path, cls) - } - - /** - * Returns an expression for serializing an object into its Spark SQL form. The mapping - * between the external and internal representations is described by encoder `enc`. The - * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. - */ - def serializerFor(enc: AgnosticEncoder[_]): Expression = { - val input = BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable) - serializerFor(enc, input) - } - - /** - * Returns an expression for serializing the value of an input expression into its Spark SQL - * representation. The mapping between the external and internal representations is described - * by encoder `enc`. - */ - private def serializerFor(enc: AgnosticEncoder[_], input: Expression): Expression = enc match { - case _ if isNativeEncoder(enc) => input - case BoxedBooleanEncoder => createSerializerForBoolean(input) - case BoxedByteEncoder => createSerializerForByte(input) - case BoxedShortEncoder => createSerializerForShort(input) - case BoxedIntEncoder => createSerializerForInteger(input) - case BoxedLongEncoder => createSerializerForLong(input) - case BoxedFloatEncoder => createSerializerForFloat(input) - case BoxedDoubleEncoder => createSerializerForDouble(input) - case JavaEnumEncoder(_) => createSerializerForJavaEnum(input) - case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input) - case StringEncoder => createSerializerForString(input) - case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt) - case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt) - case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt) - case ScalaBigIntEncoder => createSerializerForBigInteger(input) - case JavaBigIntEncoder => createSerializerForBigInteger(input) - case DayTimeIntervalEncoder => createSerializerForJavaDuration(input) - case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input) - case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input) - case DateEncoder(false) => createSerializerForSqlDate(input) - case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input) - case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input) - case TimestampEncoder(false) => createSerializerForSqlTimestamp(input) - case InstantEncoder(false) => createSerializerForJavaInstant(input) - case LocalDateTimeEncoder => createSerializerForLocalDateTime(input) - case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass) - case OptionEncoder(valueEnc) => - serializerFor(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input)) - - case ArrayEncoder(elementEncoder, containsNull) => - if (elementEncoder.isPrimitive) { - createSerializerForPrimitiveArray(input, elementEncoder.dataType) - } else { - serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false) - } - - case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) => - val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) { - // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. - // Note that the property of `Set` is only kept when manipulating the data as domain object. 
- Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - } else { - input - } - serializerForArray(elementEncoder, containsNull, getter, lenientSerialization) - - case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) => - createSerializerForMap( - input, - MapElementInformation( - ObjectType(classOf[AnyRef]), - nullable = keyEncoder.nullable, - validateAndSerializeElement(keyEncoder, keyEncoder.nullable)), - MapElementInformation( - ObjectType(classOf[AnyRef]), - nullable = valueContainsNull, - validateAndSerializeElement(valueEncoder, valueContainsNull)) - ) - - case ProductEncoder(_, fields) => - val serializedFields = fields.map { field => - // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul - // is necessary here. Because for a nullable nested inputObject with struct data - // type, e.g. StructType(IntegerType, StringType), it will return nullable=true - // for IntegerType without KnownNotNull. And that's what we do not expect to. - val getter = Invoke( - KnownNotNull(input), - field.name, - externalDataTypeFor(field.enc), - returnNullable = field.nullable) - field.name -> serializerFor(field.enc, getter) - } - createSerializerForObject(input, serializedFields) - - case RowEncoder(fields) => - val serializedFields = fields.zipWithIndex.map { case (field, index) => - val fieldValue = serializerFor( - field.enc, - ValidateExternalType( - GetExternalRowField(input, index, field.name), - field.enc.dataType, - lenientExternalDataTypeFor(field.enc))) - - val convertedField = if (field.nullable) { - exprs.If( - Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil), - // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. - // We should use `fieldValue.dataType` here. - exprs.Literal.create(null, fieldValue.dataType), - fieldValue - ) - } else { - AssertNotNull(fieldValue) - } - field.name -> convertedField - } - createSerializerForObject(input, serializedFields) - - case JavaBeanEncoder(_, fields) => - val serializedFields = fields.map { f => - val fieldValue = Invoke( - KnownNotNull(input), - f.readMethod.get, - externalDataTypeFor(f.enc), - propagateNull = f.nullable, - returnNullable = f.nullable) - f.name -> serializerFor(f.enc, fieldValue) - } - createSerializerForObject(input, serializedFields) - } - - private def serializerForArray( - elementEnc: AgnosticEncoder[_], - elementNullable: Boolean, - input: Expression, - lenientSerialization: Boolean): Expression = { - // Default serializer for Seq and generic Arrays. This does not work for primitive arrays. - val genericSerializer = createSerializerForMapObjects( - input, - ObjectType(classOf[AnyRef]), - validateAndSerializeElement(elementEnc, elementNullable)) - - // Check if it is possible the user can pass a primitive array. This is the only case when it - // is safe to directly convert to an array (for generic arrays and Seqs the type and the - // nullability can be violated). If the user has passed a primitive array we create a special - // code path to deal with these. 
- val primitiveEncoderOption = elementEnc match { - case _ if !lenientSerialization => None - case enc: PrimitiveLeafEncoder[_] => Option(enc) - case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive) - case _ => None - } - primitiveEncoderOption match { - case Some(primitiveEncoder) => - val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass - val check = Invoke( - targetObject = exprs.Literal.fromObject(primitiveArrayClass), - functionName = "isInstance", - BooleanType, - arguments = input :: Nil, - propagateNull = false, - returnNullable = false) - exprs.If( - check, - // TODO replace this with `createSerializerForPrimitiveArray` as - // soon as Cast support ObjectType casts. - StaticInvoke( - classOf[ArrayData], - ArrayType(elementEnc.dataType, containsNull = false), - "toArrayData", - input :: Nil, - propagateNull = false, - returnNullable = false), - genericSerializer) - case None => - genericSerializer - } - } - - private def validateAndSerializeElement( - enc: AgnosticEncoder[_], - nullable: Boolean): Expression => Expression = { input => - expressionWithNullSafety( - serializerFor( - enc, - ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))), - nullable, - WalkedTypePath()) - } - /** * Returns the parameter names for the primary constructor of this class. * @@ -606,15 +117,6 @@ object ScalaReflection extends ScalaReflection { } } - /** - * Returns the parameter values for the primary constructor of this class. - */ - def getConstructorParameterValues(obj: DefinedByConstructorParams): Seq[AnyRef] = { - getConstructorParameterNames(obj.getClass).map { name => - obj.getClass.getMethod(name).invoke(obj) - } - } - private def erasure(tpe: Type): Type = { // For user-defined AnyVal classes, we should not erasure it. Otherwise, it will // resolve to underlying type which wrapped by this class, e.g erasure @@ -651,13 +153,6 @@ object ScalaReflection extends ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** Returns a Sequence of attributes for the given case class type. */ - def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case Schema(s: StructType, _) => - s.toAttributes - case others => throw QueryExecutionErrors.attributesForTypeUnsupportedError(others) - } - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) @@ -713,72 +208,6 @@ object ScalaReflection extends ScalaReflection { } } - val typeJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( - BooleanType -> classOf[Boolean], - ByteType -> classOf[Byte], - ShortType -> classOf[Short], - IntegerType -> classOf[Int], - LongType -> classOf[Long], - FloatType -> classOf[Float], - DoubleType -> classOf[Double], - StringType -> classOf[UTF8String], - DateType -> classOf[PhysicalIntegerType.InternalType], - TimestampType -> classOf[PhysicalLongType.InternalType], - TimestampNTZType -> classOf[PhysicalLongType.InternalType], - BinaryType -> classOf[PhysicalBinaryType.InternalType], - CalendarIntervalType -> classOf[CalendarInterval] - ) - - val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( - BooleanType -> classOf[java.lang.Boolean], - ByteType -> classOf[java.lang.Byte], - ShortType -> classOf[java.lang.Short], - IntegerType -> classOf[java.lang.Integer], - LongType -> classOf[java.lang.Long], - FloatType -> classOf[java.lang.Float], - DoubleType -> classOf[java.lang.Double], - DateType -> classOf[java.lang.Integer], - TimestampType -> classOf[java.lang.Long], - TimestampNTZType -> classOf[java.lang.Long] - ) - - def dataTypeJavaClass(dt: DataType): Class[_] = { - dt match { - case _: DecimalType => classOf[Decimal] - case _: DayTimeIntervalType => classOf[PhysicalLongType.InternalType] - case _: YearMonthIntervalType => classOf[PhysicalIntegerType.InternalType] - case _: StructType => classOf[InternalRow] - case _: ArrayType => classOf[ArrayData] - case _: MapType => classOf[MapData] - case ObjectType(cls) => cls - case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) - } - } - - @scala.annotation.tailrec - def javaBoxedType(dt: DataType): Class[_] = dt match { - case _: DecimalType => classOf[Decimal] - case _: DayTimeIntervalType => classOf[java.lang.Long] - case _: YearMonthIntervalType => classOf[java.lang.Integer] - case BinaryType => classOf[Array[Byte]] - case StringType => classOf[UTF8String] - case CalendarIntervalType => classOf[CalendarInterval] - case _: StructType => classOf[InternalRow] - case _: ArrayType => classOf[ArrayData] - case _: MapType => classOf[MapData] - case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) - case ObjectType(cls) => cls - case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) - } - - def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { - if (arguments != Nil) { - arguments.map(e => dataTypeJavaClass(e.dataType)) - } else { - Seq.empty - } - } - def encodeFieldNameToIdentifier(fieldName: String): String = { TermName(fieldName).encodedName.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 33b0edb0c440a..7a4061a4b5605 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -17,9 +17,16 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData} +import scala.language.existentials + +import org.apache.spark.sql.catalyst.{expressions => exprs} +import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety +import 
org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -194,7 +201,7 @@ object SerializerBuildHelper { dataType: DataType): Expression = { StaticInvoke( classOf[UnsafeArrayData], - ArrayType(dataType, false), + ArrayType(dataType, containsNull = false), "fromPrimitiveArray", inputObject :: Nil, returnNullable = false) @@ -264,4 +271,190 @@ object SerializerBuildHelper { val obj = NewInstance(udtClass, Nil, dataType = ObjectType(udtClass)) Invoke(obj, "serialize", udt, inputObject :: Nil) } + + /** + * Returns an expression for serializing an object into its Spark SQL form. The mapping + * between the external and internal representations is described by encoder `enc`. The + * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`. + */ + def createSerializer(enc: AgnosticEncoder[_]): Expression = { + val input = BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable) + createSerializer(enc, input) + } + + /** + * Returns an expression for serializing the value of an input expression into its Spark SQL + * representation. The mapping between the external and internal representations is described + * by encoder `enc`. 
+ */ + private def createSerializer(enc: AgnosticEncoder[_], input: Expression): Expression = enc match { + case _ if isNativeEncoder(enc) => input + case BoxedBooleanEncoder => createSerializerForBoolean(input) + case BoxedByteEncoder => createSerializerForByte(input) + case BoxedShortEncoder => createSerializerForShort(input) + case BoxedIntEncoder => createSerializerForInteger(input) + case BoxedLongEncoder => createSerializerForLong(input) + case BoxedFloatEncoder => createSerializerForFloat(input) + case BoxedDoubleEncoder => createSerializerForDouble(input) + case JavaEnumEncoder(_) => createSerializerForJavaEnum(input) + case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input) + case StringEncoder => createSerializerForString(input) + case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt) + case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt) + case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt) + case ScalaBigIntEncoder => createSerializerForBigInteger(input) + case JavaBigIntEncoder => createSerializerForBigInteger(input) + case DayTimeIntervalEncoder => createSerializerForJavaDuration(input) + case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input) + case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input) + case DateEncoder(false) => createSerializerForSqlDate(input) + case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input) + case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input) + case TimestampEncoder(false) => createSerializerForSqlTimestamp(input) + case InstantEncoder(false) => createSerializerForJavaInstant(input) + case LocalDateTimeEncoder => createSerializerForLocalDateTime(input) + case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass) + case OptionEncoder(valueEnc) => + createSerializer(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input)) + + case ArrayEncoder(elementEncoder, containsNull) => + if (elementEncoder.isPrimitive) { + createSerializerForPrimitiveArray(input, elementEncoder.dataType) + } else { + serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false) + } + + case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) => + val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) { + // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. + // Note that the property of `Set` is only kept when manipulating the data as domain object. + Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) + } else { + input + } + serializerForArray(elementEncoder, containsNull, getter, lenientSerialization) + + case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) => + createSerializerForMap( + input, + MapElementInformation( + ObjectType(classOf[AnyRef]), + nullable = keyEncoder.nullable, + validateAndSerializeElement(keyEncoder, keyEncoder.nullable)), + MapElementInformation( + ObjectType(classOf[AnyRef]), + nullable = valueContainsNull, + validateAndSerializeElement(valueEncoder, valueContainsNull)) + ) + + case ProductEncoder(_, fields) => + val serializedFields = fields.map { field => + // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul + // is necessary here. Because for a nullable nested inputObject with struct data + // type, e.g. 
StructType(IntegerType, StringType), it will return nullable=true + // for IntegerType without KnownNotNull. And that's what we do not expect to. + val getter = Invoke( + KnownNotNull(input), + field.name, + externalDataTypeFor(field.enc), + returnNullable = field.nullable) + field.name -> createSerializer(field.enc, getter) + } + createSerializerForObject(input, serializedFields) + + case AgnosticEncoders.RowEncoder(fields) => + val serializedFields = fields.zipWithIndex.map { case (field, index) => + val fieldValue = createSerializer( + field.enc, + ValidateExternalType( + GetExternalRowField(input, index, field.name), + field.enc.dataType, + lenientExternalDataTypeFor(field.enc))) + + val convertedField = if (field.nullable) { + exprs.If( + Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil), + // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. + // We should use `fieldValue.dataType` here. + exprs.Literal.create(null, fieldValue.dataType), + fieldValue + ) + } else { + AssertNotNull(fieldValue) + } + field.name -> convertedField + } + createSerializerForObject(input, serializedFields) + + case JavaBeanEncoder(_, fields) => + val serializedFields = fields.map { f => + val fieldValue = Invoke( + KnownNotNull(input), + f.readMethod.get, + externalDataTypeFor(f.enc), + propagateNull = f.nullable, + returnNullable = f.nullable) + f.name -> createSerializer(f.enc, fieldValue) + } + createSerializerForObject(input, serializedFields) + } + + private def serializerForArray( + elementEnc: AgnosticEncoder[_], + elementNullable: Boolean, + input: Expression, + lenientSerialization: Boolean): Expression = { + // Default serializer for Seq and generic Arrays. This does not work for primitive arrays. + val genericSerializer = createSerializerForMapObjects( + input, + ObjectType(classOf[AnyRef]), + validateAndSerializeElement(elementEnc, elementNullable)) + + // Check if it is possible the user can pass a primitive array. This is the only case when it + // is safe to directly convert to an array (for generic arrays and Seqs the type and the + // nullability can be violated). If the user has passed a primitive array we create a special + // code path to deal with these. + val primitiveEncoderOption = elementEnc match { + case _ if !lenientSerialization => None + case enc: PrimitiveLeafEncoder[_] => Option(enc) + case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive) + case _ => None + } + primitiveEncoderOption match { + case Some(primitiveEncoder) => + val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass + val check = Invoke( + targetObject = exprs.Literal.fromObject(primitiveArrayClass), + functionName = "isInstance", + BooleanType, + arguments = input :: Nil, + propagateNull = false, + returnNullable = false) + exprs.If( + check, + // TODO replace this with `createSerializerForPrimitiveArray` as + // soon as Cast support ObjectType casts. 
+ StaticInvoke( + classOf[ArrayData], + ArrayType(elementEnc.dataType, containsNull = false), + "toArrayData", + input :: Nil, + propagateNull = false, + returnNullable = false), + genericSerializer) + case None => + genericSerializer + } + } + + private def validateAndSerializeElement( + enc: AgnosticEncoder[_], + nullable: Boolean): Expression => Expression = { input => + expressionWithNullSafety( + createSerializer( + enc, + ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))), + nullable, + WalkedTypePath()) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala new file mode 100644 index 0000000000000..4540ecffe0d21 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.encoders + +import scala.collection.Map + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, YearMonthIntervalType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +/** + * Helper class for Generating [[ExpressionEncoder]]s. + */ +object EncoderUtils { + /** + * Return the data type we expect to see when deserializing a value with encoder `enc`. + */ + private[catalyst] def externalDataTypeFor(enc: AgnosticEncoder[_]): DataType = { + externalDataTypeFor(enc, lenientSerialization = false) + } + + private[catalyst] def lenientExternalDataTypeFor(enc: AgnosticEncoder[_]): DataType = + externalDataTypeFor(enc, enc.lenientSerialization) + + private def externalDataTypeFor( + enc: AgnosticEncoder[_], + lenientSerialization: Boolean): DataType = { + // DataType can be native. 
+ if (isNativeEncoder(enc)) { + enc.dataType + } else if (lenientSerialization) { + ObjectType(classOf[java.lang.Object]) + } else { + ObjectType(enc.clsTag.runtimeClass) + } + } + + /** + * Returns true if the encoders' internal and external data type is the same. + */ + private[catalyst] def isNativeEncoder(enc: AgnosticEncoder[_]): Boolean = enc match { + case PrimitiveBooleanEncoder => true + case PrimitiveByteEncoder => true + case PrimitiveShortEncoder => true + case PrimitiveIntEncoder => true + case PrimitiveLongEncoder => true + case PrimitiveFloatEncoder => true + case PrimitiveDoubleEncoder => true + case NullEncoder => true + case CalendarIntervalEncoder => true + case BinaryEncoder => true + case _: SparkDecimalEncoder => true + case _ => false + } + + def dataTypeJavaClass(dt: DataType): Class[_] = { + dt match { + case _: DecimalType => classOf[Decimal] + case _: DayTimeIntervalType => classOf[PhysicalLongType.InternalType] + case _: YearMonthIntervalType => classOf[PhysicalIntegerType.InternalType] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case ObjectType(cls) => cls + case _ => typeJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + } + + @scala.annotation.tailrec + def javaBoxedType(dt: DataType): Class[_] = dt match { + case _: DecimalType => classOf[Decimal] + case _: DayTimeIntervalType => classOf[java.lang.Long] + case _: YearMonthIntervalType => classOf[java.lang.Integer] + case BinaryType => classOf[Array[Byte]] + case StringType => classOf[UTF8String] + case CalendarIntervalType => classOf[CalendarInterval] + case _: StructType => classOf[InternalRow] + case _: ArrayType => classOf[ArrayData] + case _: MapType => classOf[MapData] + case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType) + case ObjectType(cls) => cls + case _ => typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object]) + } + + def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = { + if (arguments != Nil) { + arguments.map(e => dataTypeJavaClass(e.dataType)) + } else { + Seq.empty + } + } + + val typeJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( + BooleanType -> classOf[Boolean], + ByteType -> classOf[Byte], + ShortType -> classOf[Short], + IntegerType -> classOf[Int], + LongType -> classOf[Long], + FloatType -> classOf[Float], + DoubleType -> classOf[Double], + StringType -> classOf[UTF8String], + DateType -> classOf[PhysicalIntegerType.InternalType], + TimestampType -> classOf[PhysicalLongType.InternalType], + TimestampNTZType -> classOf[PhysicalLongType.InternalType], + BinaryType -> classOf[PhysicalBinaryType.InternalType], + CalendarIntervalType -> classOf[CalendarInterval] + ) + + val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]]( + BooleanType -> classOf[java.lang.Boolean], + ByteType -> classOf[java.lang.Byte], + ShortType -> classOf[java.lang.Short], + IntegerType -> classOf[java.lang.Integer], + LongType -> classOf[java.lang.Long], + FloatType -> classOf[java.lang.Float], + DoubleType -> classOf[java.lang.Double], + DateType -> classOf[java.lang.Integer], + TimestampType -> classOf[java.lang.Long], + TimestampNTZType -> classOf[java.lang.Long] + ) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 8f7583c48fcac..cfcc1959a3d76 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -21,7 +21,7 @@ import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, JavaTypeInference, ScalaReflection, SerializerBuildHelper}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
 import org.apache.spark.sql.catalyst.expressions._
@@ -52,8 +52,8 @@ object ExpressionEncoder {
 
   def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
     new ExpressionEncoder[T](
-      ScalaReflection.serializerFor(enc),
-      ScalaReflection.deserializerFor(enc),
+      SerializerBuildHelper.createSerializer(enc),
+      DeserializerBuildHelper.createDeserializer(enc),
       enc.clsTag)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index 06ecf79c58cdf..1d65d49443596 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -21,8 +21,9 @@ import java.lang.reflect.{Method, Modifier}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, SQLConfHelper}
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.encoders.EncoderUtils
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier}
@@ -152,7 +153,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
       scalarFunc: ScalarFunction[_],
       arguments: Seq[Expression]): Expression = {
     val declaredInputTypes = scalarFunc.inputTypes().toSeq
-    val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass)
+    val argClasses = declaredInputTypes.map(EncoderUtils.dataTypeJavaClass)
     findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
       case Some(m) if Modifier.isStatic(m.getModifiers) =>
         StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 929beb660ad60..d4c5428af4d25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.encoders.EncoderUtils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -58,10 +59,10 @@ trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInput
   protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true)
   protected lazy val needNullCheckForIndex: Array[Boolean] =
     arguments.map(a => a.nullable && (propagateNull ||
-      ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray
+      EncoderUtils.dataTypeJavaClass(a.dataType).isPrimitive)).toArray
   protected lazy val evaluatedArgs: Array[Object] = new Array[Object](arguments.length)
   private lazy val boxingFn: Any => Any =
-    ScalaReflection.typeBoxedJavaMapping
+    EncoderUtils.typeBoxedJavaMapping
       .get(dataType)
       .map(cls => v => cls.cast(v))
       .getOrElse(identity)
@@ -277,7 +278,7 @@ case class StaticInvoke(
   override def children: Seq[Expression] = arguments
   override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
 
-  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+  lazy val argClasses = EncoderUtils.expressionJavaClasses(arguments)
   @transient lazy val method = findMethod(cls, functionName, argClasses)
 
   override def eval(input: InternalRow): Any = {
@@ -370,7 +371,7 @@ case class Invoke(
     returnNullable : Boolean = true,
     isDeterministic: Boolean = true) extends InvokeLike {
 
-  lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
+  lazy val argClasses = EncoderUtils.expressionJavaClasses(arguments)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE)
 
@@ -546,7 +547,7 @@ case class NewInstance(
   }
 
   @transient private lazy val constructor: (Seq[AnyRef]) => Any = {
-    val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
+    val paramTypes = EncoderUtils.expressionJavaClasses(arguments)
     val getConstructor = (paramClazz: Seq[Class[_]]) => {
       ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
         throw QueryExecutionErrors.constructorNotFoundError(cls.toString)
@@ -892,8 +893,8 @@ case class MapObjects private(
   private def elementClassTag(): ClassTag[Any] = {
     val clazz = lambdaFunction.dataType match {
       case ObjectType(cls) => cls
-      case dt if lambdaFunction.nullable => ScalaReflection.javaBoxedType(dt)
-      case dt => ScalaReflection.dataTypeJavaClass(dt)
+      case dt if lambdaFunction.nullable => EncoderUtils.javaBoxedType(dt)
+      case dt => EncoderUtils.dataTypeJavaClass(dt)
     }
     ClassTag(clazz).asInstanceOf[ClassTag[Any]]
   }
@@ -1729,7 +1730,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
       case (name, expr) =>
         // Looking for known type mapping.
         // But also looking for general `Object`-type parameter for generic methods.
-        val paramTypes = ScalaReflection.expressionJavaClasses(Seq(expr)) ++ Seq(classOf[Object])
+        val paramTypes = EncoderUtils.expressionJavaClasses(Seq(expr)) :+
+          classOf[Object]
         val methods = paramTypes.flatMap { fieldClass =>
           try {
             Some(beanClass.getDeclaredMethod(name, fieldClass))
@@ -1939,7 +1941,7 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
           value.isInstanceOf[java.sql.Timestamp] || value.isInstanceOf[java.time.Instant]
         }
       case _ =>
-        val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
+        val dataTypeClazz = EncoderUtils.javaBoxedType(dataType)
         (value: Any) => {
           dataTypeClazz.isInstance(value)
         }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 7df9f45ed835d..690e55dbe5fd1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -183,13 +183,13 @@ class ScalaReflectionSuite extends SparkFunSuite {
   // A helper method used to test `ScalaReflection.serializerForType`.
   private def serializerFor[T: TypeTag]: Expression = {
     val enc = ScalaReflection.encoderFor[T]
-    ScalaReflection.serializerFor(enc)
+    SerializerBuildHelper.createSerializer(enc)
   }
 
   // A helper method used to test `ScalaReflection.deserializerForType`.
   private def deserializerFor[T: TypeTag]: Expression = {
     val enc = ScalaReflection.encoderFor[T]
-    ScalaReflection.deserializerFor(enc)
+    DeserializerBuildHelper.createDeserializer(enc)
   }
 
   test("isSubtype") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 05ab7a65a3219..63edba80ec83b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -500,7 +500,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       val validateType = ValidateExternalType(
         GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
         dt,
-        ScalaReflection.lenientExternalDataTypeFor(enc))
+        EncoderUtils.lenientExternalDataTypeFor(enc))
       checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
     }
 
@@ -560,10 +560,10 @@
 
         ExternalMapToCatalyst(
           inputObject,
-          ScalaReflection.externalDataTypeFor(keyEnc),
+          EncoderUtils.externalDataTypeFor(keyEnc),
           kvSerializerFor(keyEnc),
           keyNullable = keyEnc.nullable,
-          ScalaReflection.externalDataTypeFor(valueEnc),
+          EncoderUtils.externalDataTypeFor(valueEnc),
           kvSerializerFor(valueEnc),
           valueNullable = valueEnc.nullable
         )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 8670d95c65e57..4f8a9e3971664 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -41,7 +41,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 
 import org.apache.spark.{SPARK_VERSION_SHORT, SparkException, TestUtils}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
@@ -1091,7 +1091,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
     withTempPath { file =>
       val path = new Path(file.toURI.toString)
       val fs = FileSystem.getLocal(hadoopConf)
-      val schema = StructType.fromAttributes(ScalaReflection.attributesFor[(Int, String)])
+      val schema = schemaFor[(Int, String)]
       writeMetadata(schema, path, hadoopConf)
 
       assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 80c27049e5a8d..30f46a3cac2d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -27,7 +27,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.parquet.schema.Type._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
 import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
 import org.apache.spark.sql.functions.desc
@@ -51,7 +50,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSparkSession {
       nanosAsLong: Boolean = false): Unit = {
     testSchema(
       testName,
-      StructType.fromAttributes(ScalaReflection.attributesFor[T]),
+      schemaFor[T],
       messageType,
       binaryAsString,
       int96AsTimestamp,
@@ -224,8 +223,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
     writeLegacyParquetFormat = true,
     expectedParquetColumn = Some(
       ParquetColumn(
-        sparkType = StructType.fromAttributes(
-          ScalaReflection.attributesFor[Tuple1[Long]]),
+        sparkType = schemaFor[Tuple1[Long]],
         descriptor = None,
         repetitionLevel = 0,
         definitionLevel = 0,
@@ -255,8 +253,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
     writeLegacyParquetFormat = true,
     expectedParquetColumn = Some(
       ParquetColumn(
-        sparkType = StructType.fromAttributes(
-          ScalaReflection.attributesFor[(Boolean, Int, Long, Float, Double, Array[Byte])]),
+        sparkType = schemaFor[(Boolean, Int, Long, Float, Double, Array[Byte])],
         descriptor = None,
         repetitionLevel = 0,
         definitionLevel = 0,
@@ -294,8 +291,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
     writeLegacyParquetFormat = true,
     expectedParquetColumn = Some(
       ParquetColumn(
-        sparkType = StructType.fromAttributes(
-          ScalaReflection.attributesFor[(Byte, Short, Int, Long, java.sql.Date)]),
+        sparkType = schemaFor[(Byte, Short, Int, Long, java.sql.Date)],
        descriptor = None,
         repetitionLevel = 0,
         definitionLevel = 0,
@@ -326,8 +322,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
     writeLegacyParquetFormat = true,
     expectedParquetColumn = Some(
       ParquetColumn(
-        sparkType = StructType.fromAttributes(
-          ScalaReflection.attributesFor[Tuple1[String]]),
+        sparkType = schemaFor[Tuple1[String]],
         descriptor = None,
         repetitionLevel = 0,
         definitionLevel = 0,
@@ -350,8 +345,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest {
     writeLegacyParquetFormat = true,
     expectedParquetColumn = Some(
       ParquetColumn(
-        sparkType = StructType.fromAttributes(
-          ScalaReflection.attributesFor[Tuple1[String]]),
+        sparkType = schemaFor[Tuple1[String]],
         descriptor = None,
         repetitionLevel = 0,
         definitionLevel = 0,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 9eca308f16fcc..1558e9733523d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -35,6 +35,7 @@ import org.apache.parquet.schema.MessageType
 
 import org.apache.spark.TestUtils
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.execution.datasources.FileBasedDataSourceTest
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -165,6 +166,8 @@
     Thread.currentThread().getContextClassLoader.getResource(name).toString
   }
 
+  protected def schemaFor[T: TypeTag]: StructType = ScalaReflection.encoderFor[T].schema
+
   def withAllParquetReaders(code: => Unit): Unit = {
     // test the row-based reader
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 0f88baeb689d9..c6bf220e45d52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.catalog.{Column, Database, Function, Table}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier}
+import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -449,10 +449,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val function = new Function("nama", "cataloa", Array("databasa"), "descripta", "classa", false)
     val column = new Column(
       "nama", "descripta", "typa", nullable = false, isPartition = true, isBucket = true)
-    val dbFields = ScalaReflection.getConstructorParameterValues(db)
-    val tableFields = ScalaReflection.getConstructorParameterValues(table)
-    val functionFields = ScalaReflection.getConstructorParameterValues(function)
-    val columnFields = ScalaReflection.getConstructorParameterValues(column)
+    val dbFields = getConstructorParameterValues(db)
+    val tableFields = getConstructorParameterValues(table)
+    val functionFields = getConstructorParameterValues(function)
+    val columnFields = getConstructorParameterValues(column)
     assert(dbFields == Seq("nama", "cataloa", "descripta", "locata"))
     assert(Seq(tableFields(0), tableFields(1), tableFields(3), tableFields(4), tableFields(5)) ==
       Seq("nama", "cataloa", "descripta", "typa", false))
@@ -1044,4 +1044,10 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
       func2.isTemporary === false &&
       func2.className.startsWith("org.apache.spark.sql.internal.CatalogSuite"))
   }
+
+  private def getConstructorParameterValues(obj: DefinedByConstructorParams): Seq[AnyRef] = {
+    ScalaReflection.getConstructorParameterNames(obj.getClass).map { name =>
+      obj.getClass.getMethod(name).invoke(obj)
+    }
+  }
 }
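
Note for reviewers: a minimal sketch (not part of the patch) of how the relocated entry points fit together after this change. Encoder inference stays in ScalaReflection, while the serializer/deserializer expressions now come from SerializerBuildHelper and DeserializerBuildHelper, which is exactly what ExpressionEncoder.apply(enc) does internally. The Person case class below is hypothetical and used only for illustration.

    import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, ScalaReflection, SerializerBuildHelper}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Hypothetical example type, not part of the patch.
    case class Person(name: String, age: Int)

    object EncoderHelpersSketch {
      def main(args: Array[String]): Unit = {
        // Encoder inference is unchanged and still lives in ScalaReflection.
        val enc = ScalaReflection.encoderFor[Person]

        // Expression generation now lives in the build helpers ...
        val serializer: Expression = SerializerBuildHelper.createSerializer(enc)
        val deserializer: Expression = DeserializerBuildHelper.createDeserializer(enc)

        // ... which is what ExpressionEncoder.apply uses under the hood.
        val exprEnc: ExpressionEncoder[Person] = ExpressionEncoder(enc)

        println(serializer)
        println(deserializer)
        println(exprEnc.clsTag)
      }
    }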
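
Similarly, a small sketch (again not part of the patch, inferred from the call sites touched above) of the Java-class mapping helpers that moved from ScalaReflection to EncoderUtils; InvokeLike, StaticInvoke, NewInstance and V2ExpressionUtils now resolve argument classes through these.

    import org.apache.spark.sql.catalyst.encoders.EncoderUtils
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types.IntegerType

    object EncoderUtilsSketch {
      def main(args: Array[String]): Unit = {
        // Primitive mapping used for method resolution (IntegerType -> the primitive int class).
        val primitive: Class[_] = EncoderUtils.dataTypeJavaClass(IntegerType)

        // Boxed mapping used when a nullable value must be boxed (IntegerType -> java.lang.Integer).
        val boxed: Class[_] = EncoderUtils.javaBoxedType(IntegerType)

        // Argument classes for a list of expressions, as StaticInvoke/Invoke/NewInstance do.
        val argClasses = EncoderUtils.expressionJavaClasses(Seq(Literal(1), Literal(2L)))

        println(primitive)
        println(boxed)
        println(argClasses)
      }
    }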