Skip to content

Commit

Permalink
[SPARK-23862][SQL] Support Java enums from Scala Dataset API
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add support for Java Enums (`java.lang.Enum`) from the Scala typed Dataset APIs. This involves adding an implicit for `Encoder` creation in `SQLImplicits`, and updating `ScalaReflection` to handle Java Enums on the serialization and deserialization pathways.

Enums are mapped to `StringType`; the stored string is the name of the Enum value.

### Why are the changes needed?
In [SPARK-21255](https://issues.apache.org/jira/browse/SPARK-21255), support for (de)serialization of Java Enums was added, but only when called from Java code. It is common for Scala code to rely on Java libraries that are outside the control of the Scala developer. Today, if a dependency on some Java code defines an Enum, it is necessary to define a corresponding Scala class. This change brings the Scala and Java APIs closer to feature parity.

### Does this PR introduce _any_ user-facing change?
Yes, previously something like:
```
val ds = Seq(MyJavaEnum.VALUE1, MyJavaEnum.VALUE2).toDS
// or
val ds = Seq(CaseClass(MyJavaEnum.VALUE1), CaseClass(MyJavaEnum.VALUE2)).toDS
```
would fail. Now, it will succeed.

### How was this patch tested?
Additional unit tests are added in `DatasetSuite`. Tests include validating top-level enums, enums inside of case classes, enums inside of arrays, and validating that the Enum is stored as the expected string.

Closes #30877 from xkrogen/xkrogen-SPARK-23862-scalareflection-java-enums.

Lead-authored-by: Erik Krogen <[email protected]>
Co-authored-by: Fangshi Li <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
2 people authored and dongjoon-hyun committed Dec 22, 2020
1 parent 1d45025 commit 303b8c8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
createDeserializerForInstant(path)

case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
createDeserializerForTypesSupportValueOf(
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
getClassFromType(t))

case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)

Expand Down Expand Up @@ -526,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
createSerializerForJavaBigInteger(inputObject)

case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
createSerializerForJavaEnum(inputObject)

case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
createSerializerForScalaBigInt(inputObject)

Expand Down Expand Up @@ -749,6 +757,7 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true)
case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true)
case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true)
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => Schema(StringType, nullable = true)
case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false)
case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false)
case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ object SerializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the serializer expression for a Java enum value.
 *
 * Invokes `name` on `inputObject` to obtain a Java `String`, then passes that
 * expression to `createSerializerForString`.
 *
 * @param inputObject expression that evaluates to the enum instance
 * @return the expression produced by `createSerializerForString` for the enum's name
 */
def createSerializerForJavaEnum(inputObject: Expression): Expression = {
  val enumName = Invoke(inputObject, "name", ObjectType(classOf[String]))
  createSerializerForString(enumName)
}

def createSerializerForSqlTimestamp(inputObject: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/**
 * Implicit [[Encoder]] for `java.time.Instant`; returns `Encoders.INSTANT`.
 *
 * @since 3.0.0
 */
implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT

/**
 * Implicit [[Encoder]] for any type `A` that is a subtype of `java.lang.Enum`,
 * created through `ExpressionEncoder` using the `TypeTag` available for `A`.
 *
 * @tparam A the Java enum type to encode
 * @since 3.2.0
 */
implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] = {
  ExpressionEncoder[A]()
}

// Boxed primitives

/** @since 2.0.0 */
Expand Down
31 changes: 31 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1693,6 +1693,33 @@ class DatasetSuite extends QueryTest
checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*)
}

test("SPARK-23862: Spark ExpressionEncoder should support Java Enum type from Scala") {
  // Several SaveMode constants plus a null element, so nullable handling is covered too.
  val modes: Seq[SaveMode] =
    Seq(SaveMode.Append, SaveMode.Overwrite, SaveMode.ErrorIfExists, SaveMode.Ignore, null)

  // 1. Enum as the top-level Dataset element.
  val expectedTopLevelSchema = new StructType().add("value", StringType, nullable = true)
  assert(modes.toDS().collect().toSeq === modes)
  assert(modes.toDS().schema === expectedTopLevelSchema)

  // 2. Enum as a case class field.
  val wrappedModes = modes.map(m => SaveModeCase(m))
  val expectedWrappedSchema = new StructType().add("mode", StringType, nullable = true)
  assert(wrappedModes.toDS().collect().toSet === wrappedModes.toSet)
  assert(wrappedModes.toDS().schema === expectedWrappedSchema)

  // 3. Enum as an array element inside a case class (one empty array, one populated).
  val arrayCases =
    Seq(SaveModeArrayCase(Array.empty[SaveMode]), SaveModeArrayCase(modes.toArray))
  val roundTripped = arrayCases.toDS().collect()
  assert(roundTripped.length === 2)
  // Order by array length so the empty-array row is first and the populated row is second.
  val shortestFirst = roundTripped.sortBy(_.modes.length)
  assert(shortestFirst(0).modes === Array())
  assert(shortestFirst(1).modes === modes.toArray)
  val expectedArraySchema =
    new StructType().add("modes", ArrayType(StringType, containsNull = true), nullable = true)
  assert(arrayCases.toDS().schema === expectedArraySchema)

  // 4. The enum is stored as a string, so a Dataset can be read as either type.
  val modeStrings = modes.map(m => Option(m).map(_.toString).orNull)
  assert(modeStrings.toDS().as[SaveMode].collect().toSet === modes.toSet)
  assert(modes.toDS().as[String].collect().toSet === modeStrings.toSet)
}

test("SPARK-24571: filtering of string values by char literal") {
val df = Seq("Amsterdam", "San Francisco", "X").toDF("city")
checkAnswer(df.where($"city" === 'X'), Seq(Row("X")))
Expand Down Expand Up @@ -2053,3 +2080,7 @@ case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
/**
 * Holds a list of `CircularReferenceClassD`, which in turn maps to `CircularReferenceClassE`,
 * so the two types refer to each other.
 */
case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])

/** Case class whose field names contain a dot and a space, hence the backtick-quoted names. */
case class SpecialCharClass(`field.1`: String, `field 2`: String)

/** Used to test Java Enums from Scala code: a case class with a single `SaveMode` field. */
case class SaveModeCase(mode: SaveMode)
/** Used to test Java Enums from Scala code: a case class with an array of `SaveMode` values. */
case class SaveModeArrayCase(modes: Array[SaveMode])

0 comments on commit 303b8c8

Please sign in to comment.