diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 53c7f17ee6b2e..361c3476f5941 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -232,6 +232,11 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
         createDeserializerForInstant(path)
 
+      case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
+        createDeserializerForTypesSupportValueOf(
+          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
+          getClassFromType(t))
+
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createDeserializerForSqlTimestamp(path)
 
@@ -526,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
         createSerializerForJavaBigInteger(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
+        createSerializerForJavaEnum(inputObject)
+
       case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
         createSerializerForScalaBigInt(inputObject)
 
@@ -749,6 +757,7 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true)
       case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true)
       case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => Schema(StringType, nullable = true)
       case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false)
       case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false)
       case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 85acaa11230b4..0554f0f76708b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -74,6 +74,9 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForJavaEnum(inputObject: Expression): Expression =
+    createSerializerForString(Invoke(inputObject, "name", ObjectType(classOf[String])))
+
   def createSerializerForSqlTimestamp(inputObject: Expression): Expression = {
     StaticInvoke(
       DateTimeUtils.getClass,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 71cbc3ab14d97..1135c8848bc23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -88,6 +88,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 3.0.0 */
   implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT
 
+  /** @since 3.2.0 */
+  implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] =
+    ExpressionEncoder()
+
   // Boxed primitives
 
   /** @since 2.0.0 */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 67e3ad6a80642..3a169e487827a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1693,6 +1693,33 @@ class DatasetSuite extends QueryTest
     checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*)
   }
 
+  test("SPARK-23862: Spark ExpressionEncoder should support Java Enum type from Scala") {
+    val saveModeSeq =
+      Seq(SaveMode.Append, SaveMode.Overwrite, SaveMode.ErrorIfExists, SaveMode.Ignore, null)
+    assert(saveModeSeq.toDS().collect().toSeq === saveModeSeq)
+    assert(saveModeSeq.toDS().schema === new StructType().add("value", StringType, nullable = true))
+
+    val saveModeCaseSeq = saveModeSeq.map(SaveModeCase.apply)
+    assert(saveModeCaseSeq.toDS().collect().toSet === saveModeCaseSeq.toSet)
+    assert(saveModeCaseSeq.toDS().schema ===
+      new StructType().add("mode", StringType, nullable = true))
+
+    val saveModeArrayCaseSeq =
+      Seq(SaveModeArrayCase(Array()), SaveModeArrayCase(saveModeSeq.toArray))
+    val collected = saveModeArrayCaseSeq.toDS().collect()
+    assert(collected.length === 2)
+    val sortedByLength = collected.sortBy(_.modes.length)
+    assert(sortedByLength(0).modes === Array())
+    assert(sortedByLength(1).modes === saveModeSeq.toArray)
+    assert(saveModeArrayCaseSeq.toDS().schema ===
+      new StructType().add("modes", ArrayType(StringType, containsNull = true), nullable = true))
+
+    // Enum is stored as string, so it is possible to convert to/from string
+    val stringSeq = saveModeSeq.map(Option.apply).map(_.map(_.toString).orNull)
+    assert(stringSeq.toDS().as[SaveMode].collect().toSet === saveModeSeq.toSet)
+    assert(saveModeSeq.toDS().as[String].collect().toSet === stringSeq.toSet)
+  }
+
   test("SPARK-24571: filtering of string values by char literal") {
     val df = Seq("Amsterdam", "San Francisco", "X").toDF("city")
     checkAnswer(df.where($"city" === 'X'), Seq(Row("X")))
@@ -2053,3 +2080,7 @@ case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
 case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
 
 case class SpecialCharClass(`field.1`: String, `field 2`: String)
+
+/** Used to test Java Enums from Scala code */
+case class SaveModeCase(mode: SaveMode)
+case class SaveModeArrayCase(modes: Array[SaveMode])
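
For context (not part of the patch): a minimal usage sketch of the implicit Java enum encoder this change introduces. It assumes a Spark build that includes this patch (3.2.0+); the local session setup, the object/app names, and the choice of java.util.concurrent.TimeUnit as the example enum are illustrative only and do not come from the patch itself.

import java.util.concurrent.TimeUnit

import org.apache.spark.sql.SparkSession

object JavaEnumEncoderExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session, just for demonstration.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("java-enum-encoder-example")
      .getOrCreate()
    import spark.implicits._ // brings newJavaEnumEncoder into scope

    // Per the schemaFor change above, an enum maps to a nullable StringType
    // column, and the serializer stores the result of Enum.name().
    val ds = Seq(TimeUnit.SECONDS, TimeUnit.MINUTES).toDS()
    ds.printSchema() // value: string (nullable = true)

    // The deserializer rebuilds the enum from the stored string via valueOf,
    // so a round trip preserves the original values.
    assert(ds.collect().toSeq == Seq(TimeUnit.SECONDS, TimeUnit.MINUTES))

    spark.stop()
  }
}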