diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 41398ff956edd..0ce68efe24c5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -222,6 +222,11 @@ abstract class Expression extends TreeNode[Expression] { } } +/** + * Root class for rewritten two-operand UDF expressions. By default, we assume it produces null if + * either of its operands is null. Exceptional cases require updating the optimization rule + * at [[optimizer.NullPropagation NullPropagation]] + */ abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => @@ -238,6 +243,11 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] self: Product => } +/** + * Root class for rewritten one-operand UDF expressions. By default, we assume it produces null if + * its operand is null. 
Exceptional cases require updating the optimization rule + * at [[optimizer.NullPropagation NullPropagation]] + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 08b2f11d20f5e..d2b7685e73065 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees abstract sealed class SortDirection case object Ascending extends SortDirection @@ -27,7 +28,10 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. 
*/ -case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { +case class SortOrder(child: Expression, direction: SortDirection) extends Expression + with trees.UnaryNode[Expression] { + + override def references = child.references override def dataType = child.dataType override def nullable = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aea984cf69de7..535ac1c3d56ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.types._ object Optimizer extends RuleExecutor[LogicalPlan] { val batches = Batch("ConstantFolding", Once, + NullPropagation, ConstantFolding, BooleanSimplification, SimplifyFilters, @@ -87,23 +88,18 @@ object ColumnPruning extends Rule[LogicalPlan] { /** * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with - * equivalent [[catalyst.expressions.Literal Literal]] values. + * equivalent [[catalyst.expressions.Literal Literal]] values. This rule focuses on + * propagating null values from the bottom to the top of the expression tree. */ -object ConstantFolding extends Rule[LogicalPlan] { +object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { // Skip redundant folding of literals. 
case l: Literal => l - // if it's foldable - case e if e.foldable => Literal(e.eval(null), e.dataType) case e @ Count(Literal(null, _)) => Literal(null, e.dataType) case e @ Sum(Literal(null, _)) => Literal(null, e.dataType) case e @ Average(Literal(null, _)) => Literal(null, e.dataType) - case e @ IsNull(Literal(null, _)) => Literal(true, BooleanType) - case e @ IsNull(Literal(_, _)) => Literal(false, BooleanType) case e @ IsNull(c @ Rand) => Literal(false, BooleanType) - case e @ IsNotNull(Literal(null, _)) => Literal(false, BooleanType) - case e @ IsNotNull(Literal(_, _)) => Literal(true, BooleanType) case e @ IsNotNull(c @ Rand) => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) @@ -113,10 +109,10 @@ object ConstantFolding extends Rule[LogicalPlan] { case Literal(null, _) => false case _ => true }) - if(newChildren.length == null) { + if(newChildren.length == 0) { Literal(null, e.dataType) - } else if(newChildren.length == children.length){ - e + } else if(newChildren.length == 1) { + newChildren(0) } else { Coalesce(newChildren) } @@ -126,9 +122,8 @@ object ConstantFolding extends Rule[LogicalPlan] { case Literal(candidate, _) if(candidate == v) => true case _ => false })) => Literal(true, BooleanType) - - case e @ SortOrder(_, _) => e - // put exceptional cases(Unary & Binary Expression) before here. + // Put exceptional cases (unary & binary expressions that do not produce null for a + // constant null operand) before here. case e: UnaryExpression => e.child match { case Literal(null, _) => Literal(null, e.dataType) case _ => e @@ -141,6 +136,19 @@ object ConstantFolding extends Rule[LogicalPlan] { } } } +/** + * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with + * equivalent [[catalyst.expressions.Literal Literal]] values. 
+ */ +object ConstantFolding extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + // Skip redundant folding of literals. + case l: Literal => l + case e if e.foldable => Literal(e.eval(null), e.dataType) + } + } +} /** * Simplifies boolean expressions where the answer can be determined without evaluating both sides.