diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 02b4cbd604903..dedb020337428 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -74,8 +74,17 @@ trait ScalaReflection {
         }), nullable = true)
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
+    case t if t <:< typeOf[Array[Int]] => Schema(ArrayType(IntegerType, false), nullable = true)
+    case t if t <:< typeOf[Array[Long]] => Schema(ArrayType(LongType, false), nullable = true)
+    case t if t <:< typeOf[Array[Double]] => Schema(ArrayType(DoubleType, false), nullable = true)
+    case t if t <:< typeOf[Array[Short]] => Schema(ArrayType(ShortType, false), nullable = true)
+    case t if t <:< typeOf[Array[Boolean]] => Schema(ArrayType(BooleanType, false), nullable = true)
+    case t if t <:< typeOf[Array[Float]] => Schema(ArrayType(FloatType, false), nullable = true)
+    case t if t <:< typeOf[Array[String]] => Schema(ArrayType(StringType, false), nullable = true)
     case t if t <:< typeOf[Array[_]] =>
-      sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+      val TypeRef(_, _, Seq(elementType)) = t
+      val Schema(dataType, nullable) = schemaFor(elementType)
+      Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
     case t if t <:< typeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
       val Schema(dataType, nullable) = schemaFor(elementType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
index e4d9a8e7f2385..5944556b5e170 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -18,9 +18,11 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 /**
  * A collection of Scala macros for working with SQL in a type-safe way.
  */
-private[sql] object SQLMacros {
+object SQLMacros {
   import scala.reflect.macros._
 
+  var currentContext: SQLContext = _
+
   def sqlImpl(c: Context)(args: c.Expr[Any]*) =
     new Macros[c.type](c).sql(args)
 
@@ -68,10 +70,29 @@ private[sql] object SQLMacros {
 
     case class RecSchema(name: String, index: Int, cType: DataType, tpe: Type)
 
+    def getSchema(sqlQuery: String, interpolatedArguments: Seq[InterpolatedItem]) = {
+      if (currentContext == null) {
+        val parser = new SqlParser()
+        val logicalPlan = parser(sqlQuery)
+        val catalog = new SimpleCatalog(true)
+        val functionRegistry = new SimpleFunctionRegistry
+        val analyzer = new Analyzer(catalog, functionRegistry, true)
+
+        interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
+        val analyzedPlan = analyzer(logicalPlan)
+
+        analyzedPlan.output.map(attr => (attr.name, attr.dataType))
+      } else {
+        interpolatedArguments.foreach(
+          _.localRegister(currentContext.catalog, currentContext.functionRegistry))
+        currentContext.sql(sqlQuery).schema.fields.map(attr => (attr.name, attr.dataType))
+      }
+    }
+
     def sql(args: Seq[c.Expr[Any]]) = {
       val q"""
-        org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
+        $path.SQLInterpolation(
           scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
 
       //rawParts.map(_.toString).foreach(println)
 
@@ -96,16 +117,7 @@ private[sql] object SQLMacros {
         interpolatedArguments(i).placeholderName + parts(i + 1)
       }.mkString("")
 
-      val parser = new SqlParser()
-      val logicalPlan = parser(query)
-      val catalog = new SimpleCatalog(true)
-      val functionRegistry = new SimpleFunctionRegistry
-      val analyzer = new Analyzer(catalog, functionRegistry, true)
-
-      interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
-      val analyzedPlan = analyzer(logicalPlan)
-
-      val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
+      val fields = getSchema(query, interpolatedArguments)
 
       val record = genRecord(q"row", fields)
 
       val tree = q"""
@@ -157,16 +169,50 @@ private[sql] object SQLMacros {
      * Constructs a nested record if necessary
      */
     def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
+      case BinaryType =>
+        q"$row($index).asInstanceOf[Array[Byte]]"
+      case DecimalType =>
+        q"$row($index).asInstanceOf[scala.math.BigDecimal]"
       case t: PrimitiveType =>
+        // this case doesn't work for DecimalType or BinaryType,
+        // note that they both extend PrimitiveType
         val methodName = newTermName("get" + primitiveForType(t))
         q"$row.$methodName($index)"
+      case ArrayType(elementType, _) =>
+        val tpe = typeOfDataType(elementType)
+        q"$row($index).asInstanceOf[Array[$tpe]]"
       case StructType(structFields) =>
         val fields = structFields.map(f => (f.name, f.dataType))
         genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
       case _ =>
         c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
     }
-  }
+
+    private def typeOfDataType(dt: DataType): Type = dt match {
+      case ArrayType(elementType, _) =>
+        val elemTpe = typeOfDataType(elementType)
+        appliedType(definitions.ArrayClass.toType, List(elemTpe))
+      case TimestampType =>
+        typeOf[java.sql.Timestamp]
+      case DecimalType =>
+        typeOf[BigDecimal]
+      case BinaryType =>
+        typeOf[Array[Byte]]
+      case _ if dt.isPrimitive =>
+        typeOfPrimitive(dt.asInstanceOf[PrimitiveType])
+    }
+
+    private def typeOfPrimitive(dt: PrimitiveType): Type = dt match {
+      case IntegerType => typeOf[Int]
+      case LongType => typeOf[Long]
+      case ShortType => typeOf[Short]
+      case ByteType => typeOf[Byte]
+      case DoubleType => typeOf[Double]
+      case FloatType => typeOf[Float]
+      case BooleanType => typeOf[Boolean]
+      case StringType => typeOf[String]
+    }
+  } // end of class Macros
 
   // TODO: Duplicated from codegen PR...
   protected def primitiveForType(dt: PrimitiveType) = dt match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
index d087f8c101c8f..0d975085c328e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
@@ -21,10 +21,28 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.sql.test.TestSQLContext
 
+import scala.math.BigDecimal
+import scala.language.reflectiveCalls
+
+import java.sql.Timestamp
+
 case class Person(name: String, age: Int)
 
 case class Car(owner: Person, model: String)
 
+case class Garage(cars: Array[Car])
+
+case class DataInt(arr: Array[Int])
+case class DataDouble(arr: Array[Double])
+case class DataFloat(arr: Array[Float])
+case class DataString(arr: Array[String])
+case class DataByte(arr: Array[Byte])
+case class DataLong(arr: Array[Long])
+case class DataShort(arr: Array[Short])
+case class DataArrayShort(arr: Array[Array[Short]])
+case class DataBigDecimal(arr: Array[BigDecimal])
+case class DataTimestamp(arr: Array[Timestamp])
+
 class TypedSqlSuite extends FunSuite {
   import TestSQLContext._
 
@@ -35,11 +53,19 @@ class TypedSqlSuite extends FunSuite {
   val cars = sparkContext.parallelize(
     Car(Person("Michael", 30), "GrandAm") :: Nil)
 
+  val garage = sparkContext.parallelize(
+    Array(Car(Person("Michael", 30), "GrandAm"), Car(Person("Mary", 52), "Buick")))
+
   test("typed query") {
     val results = sql"SELECT name FROM $people WHERE age = 30"
     assert(results.first().name == "Michael")
   }
 
+  test("typed query with array") {
+    val results = sql"SELECT * FROM $garage"
+    assert(results.first().owner.name == "Michael")
+  }
+
   test("int results") {
     val results = sql"SELECT * FROM $people WHERE age = 30"
     assert(results.first().name == "Michael")
@@ -73,4 +99,68 @@ class TypedSqlSuite extends FunSuite {
 //    def addOne(i: Int) = i + 1
 //    assert(sql"SELECT $addOne(1) as two".first.two === 2)
   }
+
+
+  // tests for different configurations of arrays, primitive and nested
+  val sqlContext = new org.apache.spark.sql.SQLContext(sparkContext)
+
+  test("array int results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataInt(Array(1, 2, 3)))
+    val ai = sql"SELECT arr FROM $data"
+    assert(ai.take(1).head.arr === Array(1, 2, 3))
+  }
+
+  test("array double results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataDouble(Array(1.0, 2.0, 3.0)))
+    val ad = sql"SELECT arr FROM $data"
+    assert(ad.take(1).head.arr === Array(1.0, 2.0, 3.0))
+  }
+
+  test("array float results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataFloat(Array(1F, 2F, 3F)))
+    val af = sql"SELECT arr FROM $data"
+    assert(af.take(1).head.arr === Array(1F, 2F, 3F))
+  }
+
+  test("array string results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataString(Array("hey", "yes", "no")))
+    val as = sql"SELECT arr FROM $data"
+    assert(as.take(1).head.arr === Array("hey", "yes", "no"))
+  }
+
+  test("array byte results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataByte(Array(1.toByte, 2.toByte, 3.toByte)))
+    val ab = sql"SELECT arr FROM $data"
+    assert(ab.take(1).head.arr === Array(1.toByte, 2.toByte, 3.toByte))
+  }
+
+  test("array long results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataLong(Array(1L, 2L, 3L)))
+    val al = sql"SELECT arr FROM $data"
+    assert(al.take(1).head.arr === Array(1L, 2L, 3L))
+  }
+
+  test("array short results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataShort(Array(1.toShort, 2.toShort, 3.toShort)))
+    val ash = sql"SELECT arr FROM $data"
+    assert(ash.take(1).head.arr === Array(1.toShort, 2.toShort, 3.toShort))
+  }
+
+  test("array of array of short results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataArrayShort(Array(Array(1.toShort, 2.toShort, 3.toShort))))
+    val aash = sql"SELECT arr FROM $data"
+    assert(aash.take(1).head.arr === Array(Array(1.toShort, 2.toShort, 3.toShort)))
+  }
+
+  test("array bigdecimal results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataBigDecimal(Array(new java.math.BigDecimal(1), new java.math.BigDecimal(2), new java.math.BigDecimal(3))))
+    val abd = sql"SELECT arr FROM $data"
+    assert(abd.take(1).head.arr === Array(new java.math.BigDecimal(1), new java.math.BigDecimal(2), new java.math.BigDecimal(3)))
+  }
+
+  test("array timestamp results") {
+    val data = sparkContext.parallelize(1 to 10).map(x => DataTimestamp(Array(new Timestamp(1L), new Timestamp(2L), new Timestamp(3L))))
+    val ats = sql"SELECT arr FROM $data"
+    assert(ats.take(1).head.arr === Array(new Timestamp(1L), new Timestamp(2L), new Timestamp(3L)))
+  }
 }
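
Usage sketch (mirrors the new tests in TypedSqlSuite.scala; not a definitive API description): with Array element types now handled by ScalaReflection.schemaFor and by the macro's genGetField, a typed query over a case class with an Array[Int] column comes back with a statically typed accessor. The val names below are illustrative; DataInt and sparkContext are the ones defined/imported in the test suite.

    import org.apache.spark.sql.test.TestSQLContext._

    // Same shape as the DataInt case class declared in TypedSqlSuite.scala.
    case class DataInt(arr: Array[Int])

    // Rows that each carry an Array[Int] column.
    val data = sparkContext.parallelize(1 to 10).map(_ => DataInt(Array(1, 2, 3)))

    // The macro parses and analyzes the query at compile time, so the
    // generated record exposes `arr` as Array[Int] rather than Any.
    val rows = sql"SELECT arr FROM $data"
    val first: Array[Int] = rows.first().arr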