Skip to content

Commit

Permalink
[SPARK-34989] Improve the performance of mapChildren and withNewChild…
Browse files Browse the repository at this point in the history
…ren methods

### What changes were proposed in this pull request?
One of the main performance bottlenecks in query compilation is overly-generic tree transformation methods, namely `mapChildren` and `withNewChildren` (defined in `TreeNode`). These methods have an overly-generic implementation to iterate over the children and rely on reflection to create new instances. We have observed that, especially for queries with large query plans, a significant amount of CPU cycles are wasted in these methods. In this PR we make these methods more efficient, by delegating the iteration and instantiation to concrete node types. The benchmarks show that we can expect significant performance improvement in total query compilation time in queries with large query plans (from 30-80%) and about 20% on average.

#### Problem detail
The `mapChildren` method in `TreeNode` is overly generic and costly. To be more specific, this method:
- iterates over all the fields of a node using Scala’s product iterator. While the iteration is not reflection-based, thanks to the Scala compiler generating code for `Product`, we create many anonymous functions and visit many nested structures (recursive calls).
The anonymous functions (presumably compiled to Java anonymous inner classes) also show up quite high on the list in the object allocation profiles, so we are putting unnecessary pressure on GC here.
- does a lot of comparisons. Basically for each element returned from the product iterator, we check if it is a child (contained in the list of children) and then transform it. We can avoid that by just iterating over children, but in the current implementation, we need to gather all the fields (only transform the children) so that we can instantiate the object using the reflection.
- creates objects using reflection, by delegating to the `makeCopy` method, which is several orders of magnitude slower than using the constructor.

#### Solution
The proposed solution in this PR is rather straightforward: we rewrite the `mapChildren` method using the `children` and `withNewChildren` methods. The default `withNewChildren` method suffers from the same problems as `mapChildren` and we need to make it more efficient by specializing it in concrete classes.  Similar to how each concrete query plan node already defines its children, it should also define how they can be constructed given a new list of children. Actually, the implementation is quite simple in most cases and is a one-liner thanks to the copy method present in Scala case classes. Note that we cannot abstract over the copy method, it’s generated by the compiler for case classes if no other type higher in the hierarchy defines it. For most concrete nodes, the implementation of `withNewChildren` looks like this:
```
override def withNewChildren(newChildren: Seq[LogicalPlan]): LogicalPlan = copy(children = newChildren)
```
The current `withNewChildren` method has two properties that we should preserve:

- It returns the same instance if the provided children are the same as its children, i.e., it preserves referential equality.
- It copies tags and maintains the origin links when a new copy is created.

These properties are hard to enforce in the concrete node type implementation. Therefore, we propose a template method `withNewChildrenInternal` that should be rewritten by the concrete classes and let the `withNewChildren` method take care of referential equality and copying:
```
override def withNewChildren(newChildren: Seq[LogicalPlan]): LogicalPlan = {
 if (childrenFastEquals(children, newChildren)) {
   this
 } else {
   CurrentOrigin.withOrigin(origin) {
     val res = withNewChildrenInternal(newChildren)
     res.copyTagsFrom(this)
     res
   }
 }
}
```

With the refactoring done in a previous PR (#31932) most tree node types fall in one of the categories of `Leaf`, `Unary`, `Binary` or `Ternary`. These traits have a more efficient implementation for `mapChildren` and define a more specialized version of `withNewChildrenInternal` that avoids creating unnecessary lists. For example, the `mapChildren` method in `UnaryLike` is defined as follows:
```
  override final def mapChildren(f: T => T): T = {
    val newChild = f(child)
    if (newChild fastEquals child) {
      this.asInstanceOf[T]
    } else {
      CurrentOrigin.withOrigin(origin) {
        val res = withNewChildInternal(newChild)
        res.copyTagsFrom(this.asInstanceOf[T])
        res
      }
    }
  }
```

#### Results
With this PR, we have observed significant performance improvements in query compilation time, more specifically in the analysis and optimization phases. The table below shows the TPC-DS queries that had more than 25% speedup in compilation times. Biggest speedups are observed in queries with large query plans.
| Query  | Speedup |
| ------------- | ------------- |
|q4    |29%|
|q9    |81%|
|q14a  |31%|
|q14b  |28%|
|q22   |33%|
|q33   |29%|
|q34   |25%|
|q39   |27%|
|q41   |27%|
|q44   |26%|
|q47   |28%|
|q48   |76%|
|q49   |46%|
|q56   |26%|
|q58   |43%|
|q59   |46%|
|q60   |50%|
|q65   |59%|
|q66   |46%|
|q67   |52%|
|q69   |31%|
|q70   |30%|
|q96   |26%|
|q98   |32%|

#### Binary incompatibility
Changing the `withNewChildren` in `TreeNode` breaks the binary compatibility of the code compiled against older versions of Spark because now it is expected that concrete `TreeNode` subclasses all implement the `withNewChildrenInternal` method. This is a problem, for example, when users write custom expressions. This change is the right choice, since it forces all newly added expressions to Catalyst implement it in an efficient manner and will prevent future regressions.
Please note that we have not completely removed the old implementation and renamed it to `legacyWithNewChildren`. This method will be removed in the future and for now helps the transition. There are expressions such as `UpdateFields` that have a complex way of defining children. Writing `withNewChildren` for them requires refactoring the expression. For now, these expressions use the old, slow method. In a future PR we address these expressions.

### Does this PR introduce _any_ user-facing change?

This PR does not introduce user facing changes but my break binary compatibility of the code compiled against older versions. See the binary compatibility section.

### How was this patch tested?

This PR is mainly a refactoring and passes existing tests.

Closes #32030 from dbaliafroozeh/ImprovedMapChildren.

Authored-by: Ali Afroozeh <[email protected]>
Signed-off-by: herman <[email protected]>
  • Loading branch information
dbaliafroozeh authored and hvanhovell committed Apr 9, 2021
1 parent a3d1e00 commit 0945baf
Show file tree
Hide file tree
Showing 175 changed files with 2,213 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,7 @@ private[avro] case class AvroDataToCatalyst(
"""
})
}

override protected def withNewChildInternal(newChild: Expression): AvroDataToCatalyst =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,7 @@ private[avro] case class CatalystDataToAvro(
defineCodeGen(ctx, ev, input =>
s"(byte[]) $expr.nullSafeEval($input)")
}

override protected def withNewChildInternal(newChild: Expression): CatalystDataToAvro =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ private[spark] object SummaryBuilderImpl extends Logging {
override def left: Expression = featuresExpr
override def right: Expression = weightExpr

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): MetricsAggregate =
copy(featuresExpr = newLeft, weightExpr = newRight)

override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
val features = vectorUDT.deserialize(featuresExpr.eval(row))
val weight = weightExpr.eval(row).asInstanceOf[Double]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expressio

override def terminate(): TraversableOnce[InternalRow] =
throw QueryExecutionErrors.cannotTerminateGeneratorError(this)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): UnresolvedGenerator = copy(children = newChildren)
}

case class UnresolvedFunction(
Expand All @@ -284,6 +287,15 @@ case class UnresolvedFunction(
val distinct = if (isDistinct) "distinct " else ""
s"'$name($distinct${children.mkString(", ")})"
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): UnresolvedFunction = {
if (filter.isDefined) {
copy(arguments = newChildren.dropRight(1), filter = Some(newChildren.last))
} else {
copy(arguments = newChildren)
}
}
}

object UnresolvedFunction {
Expand Down Expand Up @@ -441,6 +453,8 @@ case class MultiAlias(child: Expression, names: Seq[String])

override def toString: String = s"$child AS $names"

override protected def withNewChildInternal(newChild: Expression): MultiAlias =
copy(child = newChild)
}

/**
Expand Down Expand Up @@ -475,6 +489,11 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)

override def toString: String = s"$child[$extraction]"
override def sql: String = s"${child.sql}[${extraction.sql}]"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): UnresolvedExtractValue = {
copy(child = newLeft, extraction = newRight)
}
}

/**
Expand All @@ -499,6 +518,9 @@ case class UnresolvedAlias(
override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance")

override lazy val resolved = false

override protected def withNewChildInternal(newChild: Expression): UnresolvedAlias =
copy(child = newChild)
}

/**
Expand All @@ -520,6 +542,9 @@ case class UnresolvedSubqueryColumnAliases(
override def output: Seq[Attribute] = Nil

override lazy val resolved = false

override protected def withNewChildInternal(
newChild: LogicalPlan): UnresolvedSubqueryColumnAliases = copy(child = newChild)
}

/**
Expand All @@ -541,6 +566,9 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq
override def dataType: DataType = throw new UnresolvedException("dataType")
override def nullable: Boolean = throw new UnresolvedException("nullable")
override lazy val resolved = false

override protected def withNewChildInternal(newChild: Expression): UnresolvedDeserializer =
copy(deserializer = newChild)
}

case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression
Expand Down Expand Up @@ -587,6 +615,8 @@ case class UnresolvedHaving(
extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHaving =
copy(child = newChild)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ case class CallMethodViaReflection(children: Seq[Expression])

/** A temporary buffer used to hold intermediate results returned by children. */
@transient private lazy val buffer = new Array[Object](argExprs.length)

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): CallMethodViaReflection = copy(children = newChildren)
}

object CallMethodViaReflection {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1812,6 +1812,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
} else {
s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
}

override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)
}

/**
Expand Down Expand Up @@ -1841,6 +1843,8 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St
Some(SQLConf.STORE_ASSIGNMENT_POLICY.key),
Some(SQLConf.StoreAssignmentPolicy.LEGACY.toString))

override protected def withNewChildInternal(newChild: Expression): AnsiCast =
copy(child = newChild)
}

object AnsiCast {
Expand Down Expand Up @@ -1998,4 +2002,6 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S
case DecimalType => DecimalType.SYSTEM_DEFAULT
case _ => target.asInstanceOf[DataType]
}

override protected def withNewChildInternal(newChild: Expression): UpCast = copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ case class DynamicPruningSubquery(
buildKeys = buildKeys.map(_.canonicalized),
exprId = ExprId(0))
}

override protected def withNewChildInternal(newChild: Expression): DynamicPruningSubquery =
copy(pruningKey = newChild)
}

/**
Expand All @@ -94,4 +97,7 @@ case class DynamicPruningExpression(child: Expression)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.genCode(ctx)
}

override protected def withNewChildInternal(newChild: Expression): DynamicPruningExpression =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,37 @@ abstract class PartitionTransformExpression extends Expression with Unevaluable
*/
case class Years(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override protected def withNewChildInternal(newChild: Expression): Years = copy(child = newChild)
}

/**
* Expression for the v2 partition transform months.
*/
case class Months(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override protected def withNewChildInternal(newChild: Expression): Months = copy(child = newChild)
}

/**
* Expression for the v2 partition transform days.
*/
case class Days(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override protected def withNewChildInternal(newChild: Expression): Days = copy(child = newChild)
}

/**
* Expression for the v2 partition transform hours.
*/
case class Hours(child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override protected def withNewChildInternal(newChild: Expression): Hours = copy(child = newChild)
}

/**
* Expression for the v2 partition transform bucket.
*/
case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression {
override def dataType: DataType = IntegerType
override protected def withNewChildInternal(newChild: Expression): Bucket = copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,7 @@ case class PythonUDF(
// `resultId` can be seen as cosmetic variation in PythonUDF, as it doesn't affect the result.
this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF =
copy(children = newChildren)
}
Original file line number Diff line number Diff line change
Expand Up @@ -1195,4 +1195,7 @@ case class ScalaUDF(

resultConverter(result)
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ScalaUDF =
copy(children = newChildren)
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ case class SortOrder(
children.exists(required.child.semanticEquals) &&
direction == required.direction && nullOrdering == required.nullOrdering
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): SortOrder =
copy(child = newChildren.head, sameOrderExpressions = newChildren.tail)
}

object SortOrder {
Expand Down Expand Up @@ -226,4 +229,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
}

override def dataType: DataType = LongType

override protected def withNewChildInternal(newChild: Expression): SortPrefix =
copy(child = newChild.asInstanceOf[SortOrder])
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ case class ExpressionProxy(
}

override def hashCode(): Int = this.id.hashCode()

override protected def withNewChildInternal(newChild: Expression): ExpressionProxy =
copy(child = newChild)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ case class TimeWindow(
}
dataTypeCheck
}

override protected def withNewChildInternal(newChild: Expression): TimeWindow =
copy(timeColumn = newChild)
}

object TimeWindow {
Expand Down Expand Up @@ -155,4 +158,7 @@ case class PreciseTimestampConversion(
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input

override protected def withNewChildInternal(newChild: Expression): PreciseTimestampConversion =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ case class TryCast(child: Expression, dataType: DataType, timeZoneId: Option[Str
override def typeCheckFailureMessage: String =
AnsiCast.typeCheckFailureMessage(child.dataType, dataType, None, None)

override protected def withNewChildInternal(newChild: Expression): TryCast =
copy(child = newChild)

override def toString: String = {
s"try_cast($child as ${dataType.simpleString})"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,8 @@ case class ApproxCountDistinctForIntervals(
override def getLong(offset: Int): Long = array(offset)
override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): ApproxCountDistinctForIntervals =
copy(child = newLeft, endpointsExpression = newRight)
}
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ case class ApproximatePercentile(
override def deserialize(bytes: Array[Byte]): PercentileDigest = {
ApproximatePercentile.serializer.deserialize(bytes)
}

override protected def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): ApproximatePercentile =
copy(child = newFirst, percentageExpression = newSecond, accuracyExpression = newThird)
}

object ApproximatePercentile {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
/* count = */ If(child.isNull, count, count + 1L)
)

override protected def withNewChildInternal(newChild: Expression): Average =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ case class StddevPop(
}

override def prettyName: String = "stddev_pop"

override protected def withNewChildInternal(newChild: Expression): StddevPop =
copy(child = newChild)
}

// Compute the sample standard deviation of a column
Expand Down Expand Up @@ -197,6 +200,9 @@ case class StddevSamp(

override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("stddev_samp")

override protected def withNewChildInternal(newChild: Expression): StddevSamp =
copy(child = newChild)
}

// Compute the population variance of a column
Expand All @@ -223,6 +229,9 @@ case class VariancePop(
}

override def prettyName: String = "var_pop"

override protected def withNewChildInternal(newChild: Expression): VariancePop =
copy(child = newChild)
}

// Compute the sample variance of a column
Expand Down Expand Up @@ -250,6 +259,9 @@ case class VarianceSamp(
}

override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("var_samp")

override protected def withNewChildInternal(newChild: Expression): VarianceSamp =
copy(child = newChild)
}

@ExpressionDescription(
Expand Down Expand Up @@ -278,6 +290,9 @@ case class Skewness(
If(n === 0.0, Literal.create(null, DoubleType),
If(m2 === 0.0, divideByZeroEvalResult, sqrt(n) * m3 / sqrt(m2 * m2 * m2)))
}

override protected def withNewChildInternal(newChild: Expression): Skewness =
copy(child = newChild)
}

@ExpressionDescription(
Expand Down Expand Up @@ -306,4 +321,7 @@ case class Kurtosis(
}

override def prettyName: String = "kurtosis"

override protected def withNewChildInternal(newChild: Expression): Kurtosis =
copy(child = newChild)
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,7 @@ case class Corr(
}

override def prettyName: String = "corr"

override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Corr =
copy(x = newLeft, y = newRight)
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
)
}
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Count =
copy(children = newChildren)
}

object Count {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,7 @@ case class CountIf(predicate: Expression) extends UnevaluableAggregate with Impl
s"function $prettyName requires boolean type, not ${predicate.dataType.catalogString}"
)
}

override protected def withNewChildInternal(newChild: Expression): CountIf =
copy(predicate = newChild)
}
Loading

0 comments on commit 0945baf

Please sign in to comment.