[SPARK-35078][SQL] Add tree traversal pruning in expression rules #32280

Closed · wants to merge 12 commits
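This PR tags common expression node types with nodePatterns so that Catalyst rules can test a subtree's pattern bits and skip subtrees that cannot match. The mapping added here: Cast → CAST; Count → COUNT; UnaryPositive → UNARY_POSITIVE; BinaryArithmetic → BINARY_ARITHMETIC; Concat → CONCAT; If → IF; CaseWhen → CASE_WHEN; IsNull, IsNotNull, AssertNotNull → NULL_CHECK; Not → NOT; In → IN; InSet → INSET; And, Or → AND_OR; BinaryComparison → BINARY_COMPARISON; Like, MultiLikeBase → LIKE_FAMLIY; Upper, Lower → UPPER_OR_LOWER.

A minimal sketch of how a rule consumes these bits, assuming the transformAllExpressionsWithPruning/containsPattern API from the tree-pattern framework this PR builds on (SimplifyDoubleNegation is an illustrative name, not a rule changed in this PR):

import org.apache.spark.sql.catalyst.expressions.Not
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.NOT

object SimplifyDoubleNegation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan =
    // The pruning condition is evaluated against each node's pattern bits;
    // subtrees whose bits show no NOT node anywhere below are skipped entirely.
    plan.transformAllExpressionsWithPruning(_.containsPattern(NOT)) {
      case Not(Not(child)) => child
    }
}
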
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast.{forceNullable, resolvable
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CAST, TreePattern}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -1800,6 +1801,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))

+  final override val nodePatterns: Seq[TreePattern] = Seq(CAST)
+
   override protected val ansiEnabled: Boolean = SQLConf.get.ansiEnabled

   override def canCast(from: DataType, to: DataType): Boolean = if (ansiEnabled) {

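For context, nodePatterns feeds each node's pattern bit set, which Catalyst unions with the bits of all children; the final modifier keeps subclasses from redefining the pattern set. A simplified sketch of the roll-up idea, not the actual TreeNode.treePatternBits implementation:

import java.util.BitSet

// Simplified stand-in for Catalyst's TreeNode / TreePatternBits machinery.
trait SimpleNode {
  def nodePatterns: Seq[Int]   // stand-in for Seq[TreePattern]
  def children: Seq[SimpleNode]

  // A node's bits are its own patterns unioned with every child's bits,
  // so a single bit test can rule out an entire subtree.
  lazy val treePatternBits: BitSet = {
    val bits = new BitSet()
    nodePatterns.foreach(p => bits.set(p))
    children.foreach(c => bits.or(c.treePatternBits))
    bits
  }

  def containsPattern(pattern: Int): Boolean = treePatternBits.get(pattern)
}
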
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala

@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COUNT, TreePattern}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -48,6 +49,8 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {

   override def nullable: Boolean = false

+  final override val nodePatterns: Seq[TreePattern] = Seq(COUNT)
+
   // Return data type.
   override def dataType: DataType = LongType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

@@ -21,6 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern,
+  UNARY_POSITIVE}
 import org.apache.spark.sql.catalyst.util.{IntervalUtils, TypeUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -128,6 +130,8 @@ case class UnaryPositive(child: Expression)

   override def dataType: DataType = child.dataType

+  final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
     defineCodeGen(ctx, ev, c => c)

@@ -199,6 +203,8 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

   override def dataType: DataType = left.dataType

+  final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)
+
   override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

   /** Name of the function for this expression on a [[Decimal]] type. */

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CONCAT, TreePattern}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -2172,6 +2173,8 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio

   private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)

+  final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.isEmpty) {
       TypeCheckResult.TypeCheckSuccess

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF, TreePattern}
 import org.apache.spark.sql.types._

 // scalastyle:off line.size.limit
@@ -48,6 +49,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
   override def third: Expression = falseValue
   override def nullable: Boolean = trueValue.nullable || falseValue.nullable

+  final override val nodePatterns : Seq[TreePattern] = Seq(IF)
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (predicate.dataType != BooleanType) {
       TypeCheckResult.TypeCheckFailure(
@@ -139,6 +142,8 @@ case class CaseWhen(

   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

+  final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)
+
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     super.legacyWithNewChildren(newChildren)

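A rule in the SimplifyConditionals family could guard its traversal the same way; a hedged sketch assuming the same pruning API, with the standard constant-predicate rewrites (SimplifyIfExample is an illustrative name, not a rule changed here):

import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{CASE_WHEN, IF}

object SimplifyIfExample extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan =
    // Only subtrees whose pattern bits contain IF or CASE_WHEN are visited.
    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(IF, CASE_WHEN)) {
      case If(Literal(true, _), trueValue, _) => trueValue
      case If(Literal(false, _), _, falseValue) => falseValue
    }
}
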
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._

@@ -345,6 +346,8 @@ case class NaNvl(left: Expression, right: Expression)
 case class IsNull(child: Expression) extends UnaryExpression with Predicate {
   override def nullable: Boolean = false

+  final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
   override def eval(input: InternalRow): Any = {
     child.eval(input) == null
   }
@@ -375,6 +378,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
 case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
   override def nullable: Boolean = false

+  final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
   override def eval(input: InternalRow): Any = {
     child.eval(input) != null
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.TreePattern.{NULL_CHECK, TreePattern}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -1705,6 +1706,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
   override def foldable: Boolean = false
   override def nullable: Boolean = false

+  final override val nodePatterns: Seq[TreePattern] = Seq(NULL_CHECK)
+
   override def flatArguments: Iterator[Any] = Iterator(child)

   private val errMsg = "Null value appeared in non-nullable field:" +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -309,6 +309,8 @@ case class Not(child: Expression)

   override def inputTypes: Seq[DataType] = Seq(BooleanType)

+  final override val nodePatterns: Seq[TreePattern] = Seq(NOT)
+
   // +---------+-----------+
   // | CHILD   | NOT CHILD |
   // +---------+-----------+
@@ -435,7 +437,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)

-  override val nodePatterns: Seq[TreePattern] = Seq(IN)
+  final override val nodePatterns: Seq[TreePattern] = Seq(IN)

   override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"

@@ -548,7 +550,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with

   override def nullable: Boolean = child.nullable || hasNull

-  override val nodePatterns: Seq[TreePattern] = Seq(INSET)
+  final override val nodePatterns: Seq[TreePattern] = Seq(INSET)

   protected override def nullSafeEval(value: Any): Any = {
     if (set.contains(value)) {
@@ -666,6 +668,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with

   override def sqlOperator: String = "AND"

+  final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
   // +---------+---------+---------+---------+
   // | AND     | TRUE    | FALSE   | UNKNOWN |
   // +---------+---------+---------+---------+
@@ -752,6 +756,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

   override def sqlOperator: String = "OR"

+  final override val nodePatterns: Seq[TreePattern] = Seq(AND_OR)
+
   // +---------+---------+---------+---------+
   // | OR      | TRUE    | FALSE   | UNKNOWN |
   // +---------+---------+---------+---------+
@@ -823,6 +829,8 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
   // finitely enumerable. The allowable types are checked below by checkInputDataTypes.
   override def inputType: AbstractDataType = AnyDataType

+  final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_COMPARISON)
+
   override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
     case TypeCheckResult.TypeCheckSuccess =>
       TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)

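In and InSet already declared nodePatterns; the change above only adds final, aligning them with the other expressions. For usage context, a rule in the OptimizeIn family could prune on these bits as below (a hedged sketch; the empty-list rewrite mirrors Spark's semantics, but PruneInExample itself is illustrative):

import org.apache.spark.sql.catalyst.expressions.{If, In, IsNotNull, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, INSET}
import org.apache.spark.sql.types.BooleanType

object PruneInExample extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan =
    // Skip any subtree whose pattern bits contain neither IN nor INSET.
    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(IN, INSET)) {
      case In(v, Seq()) =>
        // v IN () is false when v is non-null, null otherwise.
        If(IsNotNull(v), Literal(false), Literal(null, BooleanType))
    }
}
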
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern}
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -129,6 +130,8 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)

   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()

+  final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
   override def toString: String = escapeChar match {
     case '\\' => s"$left LIKE $right"
     case c => s"$left LIKE $right ESCAPE '$c'"
@@ -198,6 +201,8 @@ sealed abstract class MultiLikeBase

   override def nullable: Boolean = true

+  final override val nodePatterns: Seq[TreePattern] = Seq(LIKE_FAMLIY)
+
   protected lazy val hasNull: Boolean = patterns.contains(null)

   protected lazy val cache = patterns.filterNot(_ == null)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
@@ -406,6 +407,8 @@ case class Upper(child: Expression)
   override def convert(v: UTF8String): UTF8String = v.toUpperCase
   // scalastyle:on caselocale

+  final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
   }
@@ -432,6 +435,8 @@ case class Lower(child: Expression)
   override def convert(v: UTF8String): UTF8String = v.toLowerCase
   // scalastyle:on caselocale

+  final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
   }