Skip to content

Commit

Permalink
Go through the context classloader when reflecting on user types in S…
Browse files Browse the repository at this point in the history
…calaReflection

Replaced calls to `typeOf` with `typeTag[T].in(mirror)`. The shorter method assumes
all types can be found in the classloader that loaded scala-reflect (the primordial
classloader). This assumption is not valid in all contexts (sbt console, Eclipse launchers).

Fixed SPARK-5281
  • Loading branch information
dragos committed May 7, 2015
1 parent 4f87e95 commit d103e70
Showing 1 changed file with 37 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.types._
*/
object ScalaReflection extends ScalaReflection {
val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
val mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
}

/**
Expand All @@ -36,6 +37,9 @@ trait ScalaReflection {
/** The universe we work in (runtime or macro) */
val universe: scala.reflect.api.Universe

/** The mirror used to access types in the universe */
val mirror: universe.Mirror

import universe._

// The Predef.Map is scala.collection.immutable.Map.
Expand All @@ -52,7 +56,19 @@ trait ScalaReflection {

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor[T: TypeTag]: Schema =
ScalaReflectionLock.synchronized { schemaFor(typeOf[T]) }
ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) }

/**
* Return the Scala Type for `T` in the current classloader mirror.
*
* Use this method instead of the convenience method `universe.typeOf`, which
* assumes that all types can be found in the classloader that loaded scala-reflect classes.
* That's not necessarily the case when running using Eclipse launchers or even
* Sbt console or test (without `fork := true`).
*
* @see SPARK-5281
*/
private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
Expand All @@ -67,25 +83,25 @@ trait ScalaReflection {
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
case t if t <:< typeOf[Option[_]] =>
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Map[_, _]] =>
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[Product] =>
case t if t <:< localTypeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
val constructorSymbol = t.member(nme.CONSTRUCTOR)
Expand All @@ -107,19 +123,20 @@ trait ScalaReflection {
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType, nullable = true)
case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType, nullable = true)
case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
Expand Down

0 comments on commit d103e70

Please sign in to comment.