From 536c005f787e5f56bedbae8946603afbc8d6285e Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 28 Apr 2014 13:53:51 +0800 Subject: [PATCH] Add Exceptional case for constant folding --- .../spark/sql/catalyst/expressions/complexTypes.scala | 4 +++- .../spark/sql/catalyst/optimizer/Optimizer.scala | 11 ++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index c947155cb701c..0add35c971f43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -28,6 +28,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { val children = child :: ordinal :: Nil /** `Null` is returned for invalid ordinals. */ override def nullable = true + override def foldable = child.foldable && ordinal.foldable override def references = children.flatMap(_.references).toSet def dataType = child.dataType match { case ArrayType(dt) => dt @@ -69,7 +70,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio type EvaluatedType = Any def dataType = field.dataType - def nullable = field.nullable + override def nullable = field.nullable + override def foldable = child.foldable protected def structType = child.dataType match { case s: StructType => s diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 520101802e25b..aea984cf69de7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -91,9 +91,11 @@ object ColumnPruning extends Rule[LogicalPlan] { */ object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsDown { + case q: LogicalPlan => q transformExpressionsUp { // Skip redundant folding of literals. case l: Literal => l + // if it's foldable + case e if e.foldable => Literal(e.eval(null), e.dataType) case e @ Count(Literal(null, _)) => Literal(null, e.dataType) case e @ Sum(Literal(null, _)) => Literal(null, e.dataType) case e @ Average(Literal(null, _)) => Literal(null, e.dataType) @@ -124,15 +126,18 @@ object ConstantFolding extends Rule[LogicalPlan] { case Literal(candidate, _) if(candidate == v) => true case _ => false })) => Literal(true, BooleanType) - // TODO put exceptional cases(Unary & Binary Expression) before here. + + case e @ SortOrder(_, _) => e + // put exceptional cases(Unary & Binary Expression) before here. case e: UnaryExpression => e.child match { case Literal(null, _) => Literal(null, e.dataType) + case _ => e } case e: BinaryExpression => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case _ => e } - case e if e.foldable => Literal(e.eval(null), e.dataType) } } }