diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index bec3c36b61302..372acc7c02f9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -752,9 +752,11 @@ object NullPropagation extends Rule[LogicalPlan] {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
-    _.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT, CAST), ruleId) {
+    t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+      || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
     case q: LogicalPlan => q.transformExpressionsUpWithPruning(
-      _.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT, CAST), ruleId) {
+      t => t.containsAnyPattern(NULL_CHECK, NULL_LITERAL, COUNT)
+        || t.containsAllPatterns(WINDOW_EXPRESSION, CAST, LITERAL), ruleId) {
       case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
         Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
       case e @ AggregateExpression(Count(exprs), _, _, _, _) if exprs.forall(isNullLiteral) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index bb999ffa72b9d..f92444b6b01f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -28,9 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  FILTER, INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf