Skip to content

Commit

Permalink
[SPARK-6542][SQL] add CreateStruct
Browse files Browse the repository at this point in the history
Similar to `CreateArray`, we can add `CreateStruct` to create nested columns. marmbrus

Author: Xiangrui Meng <[email protected]>

Closes apache#5195 from mengxr/SPARK-6542 and squashes the following commits:

3795c57 [Xiangrui Meng] update error message
ae7ac3e [Xiangrui Meng] move unit test to a separate suite
85dd559 [Xiangrui Meng] use NamedExpr
c78e31a [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-6542
85f3106 [Xiangrui Meng] add CreateStruct
  • Loading branch information
mengxr authored and liancheng committed Mar 31, 2015
1 parent 314afd0 commit a05835b
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ class Analyzer(catalog: Catalog,
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
case o => o :: Nil
},
child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, co
case class CreateArray(children: Seq[Expression]) extends Expression {
override type EvaluatedType = Any

override def foldable: Boolean = !children.exists(!_.foldable)
override def foldable: Boolean = children.forall(_.foldable)

lazy val childTypes = children.map(_.dataType).distinct

Expand All @@ -142,3 +142,30 @@ case class CreateArray(children: Seq[Expression]) extends Expression {

override def toString: String = s"Array(${children.mkString(",")})"
}

/**
* Returns a Row containing the evaluation of all children expressions.
* TODO: [[CreateStruct]] does not support codegen.
*/
case class CreateStruct(children: Seq[NamedExpression]) extends Expression {
override type EvaluatedType = Row

override def foldable: Boolean = children.forall(_.foldable)

override lazy val resolved: Boolean = childrenResolved

override lazy val dataType: StructType = {
assert(resolved,
s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
val fields = children.map { child =>
StructField(child.name, child.dataType, child.nullable, child.metadata)
}
StructType(fields)
}

override def nullable: Boolean = false

override def eval(input: Row): EvaluatedType = {
Row(children.map(_.eval(input)): _*)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,34 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._


class ExpressionEvaluationSuite extends FunSuite {
class ExpressionEvaluationBaseSuite extends FunSuite {

def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
expression.eval(inputRow)
}

def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if(actual != expected) {
val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
}

def checkDoubleEvaluation(
expression: Expression,
expected: Spread[Double],
inputRow: Row = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
actual.asInstanceOf[Double] shouldBe expected
}
}

class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {

test("literals") {
checkEvaluation(Literal(1), 1)
Expand Down Expand Up @@ -134,27 +161,6 @@ class ExpressionEvaluationSuite extends FunSuite {
}
}

def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
expression.eval(inputRow)
}

def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if(actual != expected) {
val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
}

def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
actual.asInstanceOf[Double] shouldBe expected
}

test("IN") {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
Expand Down Expand Up @@ -1081,3 +1087,14 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(~c1, -2, row)
}
}

// TODO: Make the tests work with codegen.
class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite {

test("CreateStruct") {
val row = Row(1, 2, 3)
val c1 = 'a.int.at(0).as("a")
val c3 = 'c.int.at(2).as("c")
checkEvaluation(CreateStruct(Seq(c1, c3)), Row(1, 3), row)
}
}

0 comments on commit a05835b

Please sign in to comment.