[SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression #11809

Closed
wants to merge 13 commits
@@ -112,7 +112,7 @@ object Cast {
}

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {

override def toString: String = s"cast($child as ${dataType.simpleString})"

@@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval


case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

@@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}

case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class UnaryPositive(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
extended = "> SELECT _FUNC_('-1');\n1")
case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Abs(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

@@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

@@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

@@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
case class Multiply(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

@@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

@@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
case class Remainder(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

@@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}

case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {

override def toString: String = s"pmod($left, $right)"

@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
}
}

abstract class Attribute extends LeafExpression with NamedExpression {
abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {

override def references: AttributeSet = AttributeSet(this)

@@ -92,4 +92,11 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
}
}

/**
* An expression that inherits this trait is null intolerant, i.e. any null input will result
* in null output. We use this information when constructing IsNotNull constraints.
*/
trait NullIntolerant
}
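To make the new marker's contract concrete, here is a minimal, self-contained Scala sketch (toy Expr/Attr/Plus types, not Spark's Expression API) of what "null intolerant" means operationally: if any input is null, the output is null, so a non-null result implies every referenced attribute was non-null.

// Toy illustration only, not Spark code.
sealed trait Expr { def eval(row: Map[String, Any]): Any }

case class Attr(name: String) extends Expr {
  def eval(row: Map[String, Any]): Any = row.getOrElse(name, null)
}

// Null intolerant: any null input short-circuits to a null output.
case class Plus(left: Expr, right: Expr) extends Expr {
  def eval(row: Map[String, Any]): Any = (left.eval(row), right.eval(row)) match {
    case (null, _) | (_, null) => null
    case (l: Int, r: Int)      => l + r
  }
}

// Plus(Attr("a"), Attr("b")).eval(Map("a" -> 1)) == null because "b" is missing;
// conversely, a non-null sum means both "a" and "b" were non-null.

Because a Filter only keeps rows on which its condition evaluates to true (never null), every attribute that is reachable exclusively through such null-intolerant operators in the condition must be non-null on the surviving rows; that is exactly the IsNotNull constraint this change extracts.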
@@ -90,7 +90,7 @@ trait PredicateHelper {


case class Not(child: Expression)
extends UnaryExpression with Predicate with ImplicitCastInputTypes {
extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {

override def toString: String = s"NOT $child"

@@ -402,7 +402,8 @@ private[sql] object Equality {
}


case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
case class EqualTo(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = AnyDataType

@@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}


case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
case class LessThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

@@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
}


case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
case class LessThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

@@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
}


case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
case class GreaterThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

@@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
}


case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
case class GreaterThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
var isNotNullConstraints = Set.empty[Expression]

// First, we propagate constraints if the condition consists of equality and ranges. For all
// other cases, we return an empty set of constraints
constraints.foreach {
case EqualTo(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case GreaterThan(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case GreaterThanOrEqual(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case LessThan(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case LessThanOrEqual(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case Not(EqualTo(l, r)) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case _ => // No inference
}
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] =
constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))

// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
@@ -72,6 +56,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
isNotNullConstraints -- constraints
}

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
expr.children.flatMap(scanNullIntolerantExpr)
case _ => Seq.empty[Attribute]
}

/**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
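The new scanNullIntolerantExpr walks a constraint top-down and stops as soon as it reaches an operator that is not null intolerant, collecting the attributes it can still see. A rough standalone sketch of the same idea over a toy expression tree (illustrative names, not the Catalyst API):

// Collect attributes that are reachable only through null-intolerant operators.
sealed trait Node { def children: Seq[Node] }
trait Intolerant extends Node                 // null in => null out (Cast, Add, EqualTo, ...)
case class Col(name: String) extends Node { def children = Nil }
case class CastTo(child: Node) extends Intolerant { def children = Seq(child) }
case class Eq(l: Node, r: Node) extends Intolerant { def children = Seq(l, r) }
case class FirstNonNull(l: Node, r: Node) extends Node { def children = Seq(l, r) } // Coalesce-like

def scan(n: Node): Seq[String] = n match {
  case Col(name)     => Seq(name)
  case _: Intolerant => n.children.flatMap(scan)
  case _             => Seq.empty   // stop: this operator can turn null inputs into a non-null output
}

// scan(Eq(CastTo(Col("a")), Col("b")))                 == Seq("a", "b")
// scan(Eq(FirstNonNull(Col("a"), Col("b")), Col("c"))) == Seq("c")

Every attribute collected this way is then wrapped in IsNotNull, mirroring constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_)) above; a Coalesce-like operator blocks the inference because it can produce a non-null value even when some of its inputs are null.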
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}

class ConstraintPropagationSuite extends SparkFunSuite {

@@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "b")))))
}

test("infer constraints on cast") {
Member: We need more test cases here. Any bug in this PR could cause a query to return a wrong answer. We need to be very cautious.

Member Author: Added more tests.

val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
verifyConstraints(
tr.where('a.attr === 'b.attr &&
'c.attr + 100 > 'd.attr &&
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")),
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
}

test("infer isnotnull constraints from compound expressions") {
val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
verifyConstraints(
tr.where('a.attr + 'b.attr === 'c.attr &&
IsNotNull(
Cast(
Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
Cast(resolveColumn(tr, "c"), LongType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "e")),
IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))

verifyConstraints(
tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) ===
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

verifyConstraints(
tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) <
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

verifyConstraints(
tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
ExpressionSet(Seq(
(Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
(Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
Cast(resolveColumn(tr, "e") * 1000, LongType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

// The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
verifyConstraints(
tr.where('a.attr === 'c.attr &&
IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
ExpressionSet(Seq(
resolveColumn(tr, "a") === resolveColumn(tr, "c"),
IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c")))))
}
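The last case above is worth spelling out: IsNotNull itself is not null intolerant, because it maps a null input to false rather than to null, so the truth of IsNotNull(IsNotNull(b)) says nothing about b. A tiny plain-Scala illustration (not Spark code):

// isNotNull never returns null, so the outer check succeeds even when b is null;
// no IsNotNull(b) constraint may therefore be inferred from it.
def isNotNull(v: Any): Boolean = v != null

assert(isNotNull(isNotNull(null)))   // true although the inner value is null
assert(isNotNull(isNotNull(42)))     // true as well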

test("infer IsNotNull constraints from non-nullable attributes") {
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
AttributeReference("c", StringType, nullable = false)())