From c2f9b058f6ad243e35518bc1758fe7ef7ac2c25f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Mar 2016 14:02:12 +0000 Subject: [PATCH] init import. --- .../sql/catalyst/planning/patterns.scala | 22 +++++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 8 +++++++ 2 files changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 62d54df98ecc5..96eede403e9f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -202,3 +202,25 @@ object Unions { } } } + +/** + * A pattern that finds the original expression from a sequence of casts. + */ +object Casts { + def unapply(expr: Expression): Option[Attribute] = expr match { + case c: Cast => collectCasts(expr) + case _ => None + } + + private def collectCasts(e: Expression): Option[Attribute] = { + if (e.isInstanceOf[Cast]) { + collectCasts(e.children(0)) + } else { + if (e.isInstanceOf[Attribute]) { + Some(e.asInstanceOf[Attribute]) + } else { + None + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 920e989d058dc..71b9f762d49e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.Casts import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} @@ -36,6 +37,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT .union(constructIsNotNullConstraints(constraints)) .filter(constraint => constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + .map(_.transform { + case n @ IsNotNull(c) => + c match { + case Casts(a) if outputSet.contains(a) => IsNotNull(a) + case _ => n + } + }) } /**