Skip to content

Commit

Permalink
[SPARK-35078][SQL] Add tree traversal pruning in expression rules
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Added the following TreePattern enums:
- AND_OR
- BINARY_ARITHMETIC
- BINARY_COMPARISON
- CASE_WHEN
- CAST
- CONCAT
- COUNT
- IF
- LIKE_FAMLIY
- NOT
- NULL_CHECK
- UNARY_POSITIVE
- UPPER_OR_LOWER

Used them in the following rules:
- ConstantPropagation
- ReorderAssociativeOperator
- BooleanSimplification
- SimplifyBinaryComparison
- SimplifyCaseConversionExpressions
- SimplifyConditionals
- PushFoldableIntoBranches
- LikeSimplification
- NullPropagation
- SimplifyCasts
- RemoveDispensableExpressions
- CombineConcats

### Why are the changes needed?

Reduce the number of tree traversals and hence improve the query compilation latency.

### How was this patch tested?

Existing tests.

Closes #32280 from sigmod/expression.

Authored-by: Yingyi Bu <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
  • Loading branch information
sigmod authored and gengliangwang committed Apr 23, 2021
1 parent fdccd88 commit 9af338c
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
Expand Down Expand Up @@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// Returns a copy of this expression whose timeZoneId is Option(timeZoneId);
// Option(...) turns a null argument into None.
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

// Tags this node with the CAST tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(CAST)

// A val: read from SQLConf.get once, when this field is initialized.
override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled

override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {

// This expression is declared as never nullable.
override def nullable: Boolean = false

// Tags this node with the COUNT tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(COUNT)

// Return data type: always LongType.
override def dataType: DataType = LongType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression)

// The result type is the child's type.
override def dataType: DataType = child.dataType

// Tags this node with the UNARY_POSITIVE tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)

// Delegates to defineCodeGen with the identity function (c => c).
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)

Expand Down Expand Up @@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

// The result type is taken from the left operand.
override def dataType: DataType = left.dataType

// Tags this node with the BINARY_ARITHMETIC tree pattern. Declared final on the abstract
// class, so subclasses cannot override it.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)

// Lazily computed: resolved only when the children are resolved and
// checkInputDataTypes() reports success.
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

/** Name of the function for this expression on a [[Decimal]] type. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
Expand Down Expand Up @@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio

// The accepted input types: StringType, BinaryType and ArrayType.
private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)

// Tags this node with the CONCAT tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)

override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern}
import org.apache.spark.sql.types._

// scalastyle:off line.size.limit
Expand All @@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
// The third child is the false branch.
override def third: Expression = falseValue
// Nullable when either branch value is nullable; the predicate is not consulted here.
override def nullable: Boolean = trueValue.nullable || falseValue.nullable

// Tags this node with the IF tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns : Seq[TreePattern] = Seq(IF)

override def checkInputDataTypes(): TypeCheckResult = {
if (predicate.dataType != BooleanType) {
TypeCheckResult.TypeCheckFailure(
Expand Down Expand Up @@ -139,6 +142,8 @@ case class CaseWhen(

// Children are both elements of every branch tuple, in branch order, followed by the
// contents of elseValue (its declared type is outside this view).
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

// Tags this node with the CASE_WHEN tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)

// Delegates child replacement to the inherited legacyWithNewChildren.
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
super.legacyWithNewChildren(newChildren)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression)
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
// This predicate is declared as never nullable.
override def nullable: Boolean = false

// Tags this node with the NULL_CHECK tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)

// True exactly when the child evaluates to null for the given row.
override def eval(input: InternalRow): Any = {
child.eval(input) == null
}
Expand Down Expand Up @@ -375,6 +378,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
// This predicate is declared as never nullable.
override def nullable: Boolean = false

// Tags this node with the NULL_CHECK tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)

// True exactly when the child evaluates to a non-null value for the given row.
override def eval(input: InternalRow): Any = {
child.eval(input) != null
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
// Declared as never foldable and never nullable.
override def foldable: Boolean = false
override def nullable: Boolean = false

// Tags this node with the NULL_CHECK tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)

// Exposes only `child` as a flat argument.
override def flatArguments: Iterator[Any] = Iterator(child)

private val errMsg = "Null value appeared in non-nullable field:" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -309,6 +309,8 @@ case class Not(child: Expression)

// The single input is expected to be BooleanType.
override def inputTypes: Seq[DataType] = Seq(BooleanType)

// Tags this node with the NOT tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(NOT)

// +---------+-----------+
// | CHILD | NOT CHILD |
// +---------+-----------+
Expand Down Expand Up @@ -435,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
// Nullable if any child is nullable; foldable only if every child is foldable.
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

// Tags this node with the IN tree pattern.
// NOTE(review): the two declarations below appear to be the before/after lines of the diff
// (the change adds `final`); only one can exist in the real file — confirm against the repo.
override val nodePatterns: Seq[TreePattern] = Seq(IN)
final override val nodePatterns: Seq[TreePattern] = Seq(IN)

// Renders as "<value> IN (<e1>,<e2>,...)": list elements are comma-separated with no spaces.
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"

Expand Down Expand Up @@ -548,7 +550,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with

// Nullable if the child is nullable or hasNull is true (hasNull is defined outside this view).
override def nullable: Boolean = child.nullable || hasNull

// Tags this node with the INSET tree pattern.
// NOTE(review): the two declarations below appear to be the before/after lines of the diff
// (the change adds `final`); only one can exist in the real file — confirm against the repo.
override val nodePatterns: Seq[TreePattern] = Seq(INSET)
final override val nodePatterns: Seq[TreePattern] = Seq(INSET)

protected override def nullSafeEval(value: Any): Any = {
if (set.contains(value)) {
Expand Down Expand Up @@ -666,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with

// The SQL operator text for this expression.
override def sqlOperator: String = "AND"

// Tags this node with the AND_OR tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)

// +---------+---------+---------+---------+
// | AND | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
Expand Down Expand Up @@ -752,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

// The SQL operator text for this expression.
override def sqlOperator: String = "OR"

// Tags this node with the AND_OR tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)

// +---------+---------+---------+---------+
// | OR | TRUE | FALSE | UNKNOWN |
// +---------+---------+---------+---------+
Expand Down Expand Up @@ -823,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
// Accepts any data type at this level.
override def inputType: AbstractDataType = AnyDataType

// Tags this node with the BINARY_COMPARISON tree pattern. Declared final on the abstract
// class, so subclasses cannot override it.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)

override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)

// Uses Matcher.matches() on the input string.
// NOTE(review): Pattern is presumably java.util.regex.Pattern, in which case matches()
// requires the whole input to match — confirm against the file's imports.
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()

// Tags this node with the LIKE_FAMLIY tree pattern (identifier spelled as declared in
// TreePattern).
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)

override def toString: String = escapeChar match {
case '\\' => s"$left LIKE $right"
case c => s"$left LIKE $right ESCAPE '$c'"
Expand Down Expand Up @@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase

// This expression is declared as always nullable.
override def nullable: Boolean = true

// Tags this node with the LIKE_FAMLIY tree pattern (identifier spelled as declared in
// TreePattern). Declared final on the sealed base class, so subclasses cannot override it.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)

// Lazily computed: whether `patterns` contains a null entry.
protected lazy val hasNull: Boolean = patterns.contains(null)

// Lazily computed: `patterns` with the null entries removed.
protected lazy val cache = patterns.filterNot(_ == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -406,6 +407,8 @@ case class Upper(child: Expression)
// Interpreted path: upper-cases the UTF8String value.
override def convert(v: UTF8String): UTF8String = v.toUpperCase
// scalastyle:on caselocale

// Tags this node with the UPPER_OR_LOWER tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)

// Codegen path: emits "(<child code>).toUpperCase()" via defineCodeGen.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
}
Expand All @@ -432,6 +435,8 @@ case class Lower(child: Expression)
// Interpreted path: lower-cases the UTF8String value.
override def convert(v: UTF8String): UTF8String = v.toLowerCase
// scalastyle:on caselocale

// Tags this node with the UPPER_OR_LOWER tree pattern.
// NOTE(review): per the commit description, expression rules use these patterns to prune
// tree traversals; the pruning logic itself is outside this view.
final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)

// Codegen path: emits "(<child code>).toLowerCase()" via defineCodeGen.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
}
Expand Down
Loading

0 comments on commit 9af338c

Please sign in to comment.