Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-27915][SQL][WIP] Update logical Filter's output nullability based on IsNotNull conditions #24765

Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,49 @@ trait PredicateHelper {
case e: Unevaluable => false
case e => e.children.forall(canEvaluateWithinJoin)
}

/**
* Given an IsNotNull expression, returns the IDs of expressions whose not-nullness
* is implied by that IsNotNull expression.
*/
protected def getImpliedNotNullExprIds(isNotNullExpr: IsNotNull): Set[ExprId] = {
// This logic is a little tricky, so we'll use an example to build some intuition.
// Consider the expression IsNotNull(f(g(x), y)). By definition, its child is not null:
// f(g(x), y) is not null
// In addition, if `f` is NullIntolerant then it would be null if either child was null:
// g(x) is null => f(g(x), y) is null
// y is null => f(g(x), y) is null
// Via A => B <=> !A || B, we have:
// g(x) is not null || f(g(x), y) is null
// y is not null || f(g(x), y) is null
// Since we know that f(g(x), y) is not null, we must therefore conclude that
// g(x) is not null
// y is not null
// By recursively applying this logic, if g is NullIntolerant then x is not null.
// However, if g is NOT NullIntolerant (e.g. if g(null) is non-null) then we cannot
// conclude anything about x's nullability.
def getExprIdIfNamed(expr: Expression): Set[ExprId] = expr match {
case ne: NamedExpression => Set(ne.toAttribute.exprId)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be AttributeReference? I couldn't remember offhand how to get ExprIds from arbitrary expressions, hence this hack.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use AttributeSet?

case _ => Set.empty
}
// True when `expr` is itself a NullIntolerant node, or an Alias whose direct child is one;
// every other expression is treated as null-tolerant.
def isNullIntolerant(expr: Expression): Boolean = expr match {
  case _: NullIntolerant | Alias(_: NullIntolerant, _) => true
  case _ => false
}
// Recurse through the IsNotNull expression's descendants, stopping
// once we encounter a null-tolerant expression.
// Collects the ExprIds reachable from `expr`'s children: each child contributes its own
// id (when it is named), and the walk only continues below a child that is null-intolerant.
def getNotNullDescendants(expr: Expression): Set[ExprId] = {
  expr.children.flatMap { child =>
    val ownIds = getExprIdIfNamed(child)
    if (isNullIntolerant(child)) ownIds ++ getNotNullDescendants(child) else ownIds
  }.toSet
}
getExprIdIfNamed(isNotNullExpr) ++ getNotNullDescendants(isNotNullExpr)
}
}

@ExpressionDescription(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def output: Seq[Attribute] = {
// The child operator may have inferred more precise nullability information
// for the project expression, so leverage that information if it's available:
val childOutputNullability = child.output.map(a => a.exprId -> a.nullable).toMap
projectList
.map(_.toAttribute)
.map{ a => childOutputNullability.get(a.exprId).map(a.withNullability).getOrElse(a) }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to fix this part? It seems UpdateAttributeNullability could handle this case if Filter.output works well?

}
override def maxRows: Option[Long] = child.maxRows

override lazy val resolved: Boolean = {
Expand Down Expand Up @@ -129,7 +136,22 @@ case class Generate(

case class Filter(condition: Expression, child: LogicalPlan)
extends OrderPreservingUnaryNode with PredicateHelper {
override def output: Seq[Attribute] = child.output

// Output attributes of the filter: the child's output, with an attribute marked
// non-nullable when an IsNotNull conjunct of `condition` implies it cannot be null.
override def output: Seq[Attribute] = {
  // Only top-level conjuncts that are IsNotNull contribute implied non-null ids.
  val notNullIds: Set[ExprId] = splitConjunctivePredicates(condition).flatMap {
    case notNull: IsNotNull => getImpliedNotNullExprIds(notNull)
    case _ => Set.empty[ExprId]
  }.toSet
  child.output.map {
    case attr if attr.nullable && notNullIds.contains(attr.exprId) =>
      attr.withNullability(false)
    case attr => attr
  }
}

override def maxRows: Option[Long] = child.maxRows

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._

class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PredicateHelper {

def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
testFunc(false, BooleanType)
Expand Down Expand Up @@ -175,4 +175,22 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
checkEvaluation(AtLeastNNonNulls(1, inputs), true)
}

test("getImpliedNotNullExprIds") {
  val attrA = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val attrB = AttributeReference("b", IntegerType)(exprId = ExprId(2))

  // A leaf attribute directly under IsNotNull yields exactly that attribute's id.
  assert(getImpliedNotNullExprIds(IsNotNull(attrA)) == Set(attrA.exprId))

  // Coalesce is not NullIntolerant, so nothing is concluded about its inputs; the aliased
  // expression itself is still reported because it is the direct child of IsNotNull.
  val coalesced = Alias(Coalesce(Seq(attrA, attrB)), "c")(exprId = ExprId(3))
  assert(getImpliedNotNullExprIds(IsNotNull(coalesced)) == Set(coalesced.exprId))

  // Add is NullIntolerant, so the non-null constraint also reaches both of its operands.
  val sum = Alias(Add(attrA, attrB), "add")(exprId = ExprId(4))
  assert(sum.child.isInstanceOf[NullIntolerant])
  val expectedIds = Set(attrA.exprId, attrB.exprId, sum.exprId)
  assert(getImpliedNotNullExprIds(IsNotNull(sum)) == expectedIds)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,25 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
case class FilterExec(condition: Expression, child: SparkPlan)
extends UnaryExecNode with CodegenSupport with PredicateHelper {

// Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
// all the variables at the beginning to take advantage of short circuiting.
override def usedInputs: AttributeSet = AttributeSet.empty

// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the old code here to be slightly confusing because it seemed to be using notNullPreds for two different purposes:

  1. If we see IsNotNull conjuncts in the filter then evaluate them first / earlier because (a) these expressions are cheap to evaluate and may allow for short-circuiting and skipping more expensive expressions, and (b) evaluating these earlier allows other expressions to omit null checks (for example, if we have IsNotNull(x) and x * 100 < 10 then we already implicitly need to null-check x as part of the second expression so we might as well do the explicit null check expression first).
  2. Given that tuples have successfully passed through the filter, we can rely on the presence of IsNotNull checks to default subsequent expressions' null checks to false. For example, let's say we had a .filter().select() which gets compiled into a single whole stage codegen: after tuples have passed through the filter we know that certain fields cannot possibly be null, so we can elide null checks at codegen time by just setting nullable = false in subsequent code.

There might be some subtleties in (1) related to non-deterministic expressions, but I think that's accounted for further down at the place where we're actually generating the checks.

In the old code, the (notNullPreds, otherPreds) on this line was being used for both purposes: for (1) I think we could simply collect all IsNotNull expressions, but the existing implementation of (2) relied on the additional nullIntolerant / a.references checks in order to be correct.

In this PR, I've separated these two usages: the "update nullability for downstream operators" now uses the more precise condition implemented in getImpliedNotNullExprIds, while the "optimize short-circuiting" simply checks for IsNotNull and ignores child attributes.

case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
case IsNotNull(_) => true
case _ => false
}

// If one expression and its children are null intolerant, it is null intolerant.
private def isNullIntolerant(expr: Expression): Boolean = expr match {
case e: NullIntolerant => e.children.forall(isNullIntolerant)
case _ => false
private val impliedNotNullExprIds: Set[ExprId] = {
notNullPreds
.map { case n: IsNotNull => getImpliedNotNullExprIds(n) }
.foldLeft(Set.empty[ExprId])(_ ++ _)
}

// The columns that will be filtered out by `IsNotNull` could be considered as not nullable.
private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)

// Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
// all the variables at the beginning to take advantage of short circuiting.
override def usedInputs: AttributeSet = AttributeSet.empty

override def output: Seq[Attribute] = {
child.output.map { a =>
if (a.nullable && notNullAttributes.contains(a.exprId)) {
if (a.nullable && impliedNotNullExprIds.contains(a.exprId)) {
a.withNullability(false)
} else {
a
Expand Down Expand Up @@ -193,7 +190,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// Reset the isNull to false for the not-null columns, then the followed operators could
// generate better code (remove dead branches).
val resultVars = input.zipWithIndex.map { case (ev, i) =>
if (notNullAttributes.contains(child.output(i).exprId)) {
if (impliedNotNullExprIds.contains(child.output(i).exprId)) {
ev.isNull = FalseLiteral
}
ev
Expand Down