Merge pull request #11 from heathermiller/typedSql
Adding array support (waiting on merge in scala-records)
marmbrus committed Sep 14, 2014
2 parents 6e1eaf3 + 4eba0aa commit f170b0f
Showing 3 changed files with 159 additions and 14 deletions.
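In effect, this change lets a typed sql interpolation return array-valued columns as statically typed Array fields instead of rejecting them with the old "Only Array[Byte] supported" error. A minimal usage sketch adapted from the new TypedSqlSuite tests below (it assumes a SparkContext named sparkContext and the TestSQLContext sql interpolator in scope, as in those tests):

case class DataInt(arr: Array[Int])

val data = sparkContext.parallelize(1 to 10).map(_ => DataInt(Array(1, 2, 3)))
// The macro infers a record type whose `arr` field is typed Array[Int] at compile time.
val rows = sql"SELECT arr FROM $data"
rows.take(1).head.arr   // Array(1, 2, 3)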
@@ -74,8 +74,17 @@ trait ScalaReflection {
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[Int]] => Schema(ArrayType(IntegerType, false), nullable = true)
case t if t <:< typeOf[Array[Long]] => Schema(ArrayType(LongType, false), nullable = true)
case t if t <:< typeOf[Array[Double]] => Schema(ArrayType(DoubleType, false), nullable = true)
case t if t <:< typeOf[Array[Short]] => Schema(ArrayType(ShortType, false), nullable = true)
case t if t <:< typeOf[Array[Boolean]] => Schema(ArrayType(BooleanType, false), nullable = true)
case t if t <:< typeOf[Array[Float]] => Schema(ArrayType(FloatType, false), nullable = true)
case t if t <:< typeOf[Array[String]] => Schema(ArrayType(StringType, false), nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
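Reading the new cases top-down: the dedicated Array[Int], Array[Long], etc. cases produce non-nullable element types, while any other Array[T] now falls through to the generic Array[_] case, which recurses on the element type instead of raising the old sys.error. A sketch of how the match would resolve for a nested array (following the usual top-down pattern-match order):

// schemaFor(typeOf[Array[Int]])
//   => Schema(ArrayType(IntegerType, false), nullable = true)                // dedicated case
// schemaFor(typeOf[Array[Array[Short]]])                                     // generic Array[_] case
//   elementType = Array[Short]
//   schemaFor(elementType) => Schema(ArrayType(ShortType, false), nullable = true)
//   => Schema(ArrayType(ArrayType(ShortType, false), containsNull = true), nullable = true)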
72 changes: 59 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -18,9 +18,11 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
/**
* A collection of Scala macros for working with SQL in a type-safe way.
*/
private[sql] object SQLMacros {
object SQLMacros {
import scala.reflect.macros._

var currentContext: SQLContext = _

def sqlImpl(c: Context)(args: c.Expr[Any]*) =
new Macros[c.type](c).sql(args)

@@ -68,10 +70,29 @@ private[sql] object SQLMacros {

case class RecSchema(name: String, index: Int, cType: DataType, tpe: Type)

def getSchema(sqlQuery: String, interpolatedArguments: Seq[InterpolatedItem]) = {
if (currentContext == null) {
val parser = new SqlParser()
val logicalPlan = parser(sqlQuery)
val catalog = new SimpleCatalog(true)
val functionRegistry = new SimpleFunctionRegistry
val analyzer = new Analyzer(catalog, functionRegistry, true)

interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
val analyzedPlan = analyzer(logicalPlan)

analyzedPlan.output.map(attr => (attr.name, attr.dataType))
} else {
interpolatedArguments.foreach(
_.localRegister(currentContext.catalog, currentContext.functionRegistry))
currentContext.sql(sqlQuery).schema.fields.map(attr => (attr.name, attr.dataType))
}
}
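// Note on the two paths above, as far as this diff shows: when a live SQLContext has been
// installed in `currentContext`, the interpolated relations and UDFs are registered with that
// context and the result schema is read from currentContext.sql(sqlQuery).schema; otherwise
// the query is parsed and analyzed standalone against a fresh SimpleCatalog and
// SimpleFunctionRegistry that only know about the interpolated arguments.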

def sql(args: Seq[c.Expr[Any]]) = {

val q"""
org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
$path.SQLInterpolation(
scala.StringContext.apply(..$rawParts))""" = c.prefix.tree

//rawParts.map(_.toString).foreach(println)
@@ -96,16 +117,7 @@ private[sql] object SQLMacros {
interpolatedArguments(i).placeholderName + parts(i + 1)
}.mkString("")

val parser = new SqlParser()
val logicalPlan = parser(query)
val catalog = new SimpleCatalog(true)
val functionRegistry = new SimpleFunctionRegistry
val analyzer = new Analyzer(catalog, functionRegistry, true)

interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
val analyzedPlan = analyzer(logicalPlan)

val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
val fields = getSchema(query, interpolatedArguments)
val record = genRecord(q"row", fields)

val tree = q"""
@@ -157,16 +169,50 @@
* Constructs a nested record if necessary
*/
def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
case BinaryType =>
q"$row($index).asInstanceOf[Array[Byte]]"
case DecimalType =>
q"$row($index).asInstanceOf[scala.math.BigDecimal]"
case t: PrimitiveType =>
// this case doesn't work for DecimalType or BinaryType,
// note that they both extend PrimitiveType
val methodName = newTermName("get" + primitiveForType(t))
q"$row.$methodName($index)"
case ArrayType(elementType, _) =>
val tpe = typeOfDataType(elementType)
q"$row($index).asInstanceOf[Array[$tpe]]"
case StructType(structFields) =>
val fields = structFields.map(f => (f.name, f.dataType))
genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
case _ =>
c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
}
}

private def typeOfDataType(dt: DataType): Type = dt match {
case ArrayType(elementType, _) =>
val elemTpe = typeOfDataType(elementType)
appliedType(definitions.ArrayClass.toType, List(elemTpe))
case TimestampType =>
typeOf[java.sql.Timestamp]
case DecimalType =>
typeOf[BigDecimal]
case BinaryType =>
typeOf[Array[Byte]]
case _ if dt.isPrimitive =>
typeOfPrimitive(dt.asInstanceOf[PrimitiveType])
}

private def typeOfPrimitive(dt: PrimitiveType): Type = dt match {
case IntegerType => typeOf[Int]
case LongType => typeOf[Long]
case ShortType => typeOf[Short]
case ByteType => typeOf[Byte]
case DoubleType => typeOf[Double]
case FloatType => typeOf[Float]
case BooleanType => typeOf[Boolean]
case StringType => typeOf[String]
}
} // end of class Macros

// TODO: Duplicated from codegen PR...
protected def primitiveForType(dt: PrimitiveType) = dt match {
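For comparison, the accessors that genGetField above would emit for different column types look roughly like this (a sketch inferred from the quasiquotes in this diff, not literal macro output; the ordinals are made up for illustration):

// IntegerType column at ordinal 0:              row.getInt(0)
// ArrayType(IntegerType, _) at ordinal 1:       row(1).asInstanceOf[Array[Int]]
// ArrayType(ArrayType(ShortType, _), _) at 2:   row(2).asInstanceOf[Array[Array[Short]]]
// BinaryType at ordinal 3:                      row(3).asInstanceOf[Array[Byte]]
// DecimalType at ordinal 4:                     row(4).asInstanceOf[scala.math.BigDecimal]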
90 changes: 90 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
Expand Up @@ -21,10 +21,28 @@ import org.scalatest.FunSuite

import org.apache.spark.sql.test.TestSQLContext

import scala.math.BigDecimal
import scala.language.reflectiveCalls

import java.sql.Timestamp

case class Person(name: String, age: Int)

case class Car(owner: Person, model: String)

case class Garage(cars: Array[Car])

case class DataInt(arr: Array[Int])
case class DataDouble(arr: Array[Double])
case class DataFloat(arr: Array[Float])
case class DataString(arr: Array[String])
case class DataByte(arr: Array[Byte])
case class DataLong(arr: Array[Long])
case class DataShort(arr: Array[Short])
case class DataArrayShort(arr: Array[Array[Short]])
case class DataBigDecimal(arr: Array[BigDecimal])
case class DataTimestamp(arr: Array[Timestamp])

class TypedSqlSuite extends FunSuite {
import TestSQLContext._

@@ -35,11 +53,19 @@ class TypedSqlSuite extends FunSuite {
val cars = sparkContext.parallelize(
Car(Person("Michael", 30), "GrandAm") :: Nil)

val garage = sparkContext.parallelize(
Array(Car(Person("Michael", 30), "GrandAm"), Car(Person("Mary", 52), "Buick")))

test("typed query") {
val results = sql"SELECT name FROM $people WHERE age = 30"
assert(results.first().name == "Michael")
}

test("typed query with array") {
val results = sql"SELECT * FROM $garage"
assert(results.first().owner.name == "Michael")
}

test("int results") {
val results = sql"SELECT * FROM $people WHERE age = 30"
assert(results.first().name == "Michael")
@@ -73,4 +99,68 @@ class TypedSqlSuite extends FunSuite {
// def addOne(i: Int) = i + 1
// assert(sql"SELECT $addOne(1) as two".first.two === 2)
}


// tests for different configurations of arrays, primitive and nested
val sqlContext = new org.apache.spark.sql.SQLContext(sparkContext)

test("array int results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataInt(Array(1, 2, 3)))
val ai = sql"SELECT arr FROM $data"
assert(ai.take(1).head.arr === Array(1, 2, 3))
}

test("array double results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataDouble(Array(1.0, 2.0, 3.0)))
val ad = sql"SELECT arr FROM $data"
assert(ad.take(1).head.arr === Array(1.0, 2.0, 3.0))
}

test("array float results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataFloat(Array(1F, 2F, 3F)))
val af = sql"SELECT arr FROM $data"
assert(af.take(1).head.arr === Array(1F, 2F, 3F))
}

test("array string results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataString(Array("hey","yes","no")))
val as = sql"SELECT arr FROM $data"
assert(as.take(1).head.arr === Array("hey","yes","no"))
}

test("array byte results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataByte(Array(1.toByte, 2.toByte, 3.toByte)))
val ab = sql"SELECT arr FROM $data"
assert(ab.take(1).head.arr === Array(1.toByte, 2.toByte, 3.toByte))
}

test("array long results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataLong(Array(1L, 2L, 3L)))
val al = sql"SELECT arr FROM $data"
assert(al.take(1).head.arr === Array(1L, 2L, 3L))
}

test("array short results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataShort(Array(1.toShort, 2.toShort, 3.toShort)))
val ash = sql"SELECT arr FROM $data"
assert(ash.take(1).head.arr === Array(1.toShort, 2.toShort, 3.toShort))
}

test("array of array of short results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataArrayShort(Array(Array(1.toShort, 2.toShort, 3.toShort))))
val aash = sql"SELECT arr FROM $data"
assert(aash.take(1).head.arr === Array(Array(1.toShort, 2.toShort, 3.toShort)))
}

test("array bigdecimal results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataBigDecimal(Array(new java.math.BigDecimal(1), new java.math.BigDecimal(2), new java.math.BigDecimal(3))))
val abd = sql"SELECT arr FROM $data"
assert(abd.take(1).head.arr === Array(new java.math.BigDecimal(1), new java.math.BigDecimal(2), new java.math.BigDecimal(3)))
}

test("array timestamp results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataTimestamp(Array(new Timestamp(1L), new Timestamp(2L), new Timestamp(3L))))
val ats = sql"SELECT arr FROM $data"
assert(ats.take(1).head.arr === Array(new Timestamp(1L), new Timestamp(2L), new Timestamp(3L)))
}
}
