Skip to content

Commit

Permalink
[SPARK-34079][SQL] Merge non-correlated scalar subqueries
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR adds a new optimizer rule `MergeScalarSubqueries` to merge multiple non-correlated `ScalarSubquery`s to compute multiple scalar values once.

E.g. the following query:
```
SELECT
  (SELECT avg(a) FROM t),
  (SELECT sum(b) FROM t)
```
is optimized from:
```
== Optimized Logical Plan ==
Project [scalar-subquery#242 [] AS scalarsubquery()#253, scalar-subquery#243 [] AS scalarsubquery()#254L]
:  :- Aggregate [avg(a#244) AS avg(a)#247]
:  :  +- Project [a#244]
:  :     +- Relation default.t[a#244,b#245] parquet
:  +- Aggregate [sum(a#251) AS sum(a)#250L]
:     +- Project [a#251]
:        +- Relation default.t[a#251,b#252] parquet
+- OneRowRelation
```
to:
```
== Optimized Logical Plan ==
Project [scalar-subquery#242 [].avg(a) AS scalarsubquery()#253, scalar-subquery#243 [].sum(a) AS scalarsubquery()#254L]
:  :- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
:  :  +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
:  :     +- Project [a#244]
:  :        +- Relation default.t[a#244,b#245] parquet
:  +- Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
:     +- Aggregate [avg(a#244) AS avg(a)#247, sum(a#244) AS sum(a)#250L]
:        +- Project [a#244]
:           +- Relation default.t[a#244,b#245] parquet
+- OneRowRelation
```
and in the physical plan subqueries are reused:
```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   *(1) Project [Subquery subquery#242, [id=#113].avg(a) AS scalarsubquery()#253, ReusedSubquery Subquery subquery#242, [id=#113].sum(a) AS scalarsubquery()#254L]
   :  :- Subquery subquery#242, [id=#113]
   :  :  +- AdaptiveSparkPlan isFinalPlan=true
         +- == Final Plan ==
            *(2) Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
            +- *(2) HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)], output=[avg(a)#247, sum(a)#250L])
               +- ShuffleQueryStage 0
                  +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#158]
                     +- *(1) HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)], output=[sum#262, count#263L, sum#264L])
                        +- *(1) ColumnarToRow
                           +- FileScan parquet default.t[a#244] Batched: true, DataFilters: [], Format: Parquet, Location: ..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:int>
         +- == Initial Plan ==
            Project [named_struct(avg(a), avg(a)#247, sum(a), sum(a)#250L) AS mergedValue#260]
            +- HashAggregate(keys=[], functions=[avg(a#244), sum(a#244)], output=[avg(a)#247, sum(a)#250L])
               +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#110]
                  +- HashAggregate(keys=[], functions=[partial_avg(a#244), partial_sum(a#244)], output=[sum#262, count#263L, sum#264L])
                     +- FileScan parquet default.t[a#244] Batched: true, DataFilters: [], Format: Parquet, Location: ..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<a:int>
   :  +- ReusedSubquery Subquery subquery#242, [id=#113]
   +- *(1) Scan OneRowRelation[]
+- == Initial Plan ==
...
```

Please note that the above simple example could be easily optimized into a common select expression without reuse node, but this PR can handle more complex queries as well.

### Why are the changes needed?
Performance improvement.
```
[info] TPCDS Snappy:                             Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] q9 - MergeScalarSubqueries off                    50798          52521        1423          0.0      Infinity       1.0X
[info] q9 - MergeScalarSubqueries on                     19484          19675         226          0.0      Infinity       2.6X

[info] TPCDS Snappy:                             Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
[info] ------------------------------------------------------------------------------------------------------------------------
[info] q9b - MergeScalarSubqueries off                   15430          17803         NaN          0.0      Infinity       1.0X
[info] q9b - MergeScalarSubqueries on                     3862           4002         196          0.0      Infinity       4.0X
```
Please find `q9b` in the description of SPARK-34079. It is a variant of [q9.sql](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q9.sql) using CTE.
The performance improvement in case of `q9` comes from merging 15 subqueries into 5 and in case of `q9b` it comes from merging 5 subqueries into 1.

### Does this PR introduce _any_ user-facing change?
No. But this optimization can be disabled with `spark.sql.optimizer.excludedRules` config.

### How was this patch tested?
Existing and new UTs.

Closes #32298 from peter-toth/SPARK-34079-multi-column-scalar-subquery.

Lead-authored-by: Peter Toth <[email protected]>
Co-authored-by: attilapiros <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
2 people authored and cloud-fan committed Apr 20, 2022
1 parent 21e48b7 commit e00b81e
Show file tree
Hide file tree
Showing 19 changed files with 1,706 additions and 1,600 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ case class BloomFilterMightContain(
case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
TypeCheckResult.TypeCheckSuccess
case GetStructField(subquery: PlanExpression[_], _, _)
if !subquery.containsPattern(OUTER_REFERENCE) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " +
"should be either a constant value or a scalar subquery expression")
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
private def pushdownPredicatesAndAttributes(
plan: LogicalPlan,
cteMap: CTEMap): LogicalPlan = plan.transformWithSubqueries {
case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates) =>
case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates, _) =>
val (_, _, newPreds, newAttrSet) = cteMap(id)
val originalPlan = originalPlanWithPredicates.map(_._1).getOrElse(child)
val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty)
Expand Down Expand Up @@ -169,7 +169,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
object CleanUpTempCTEInfo extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.transformWithPruning(_.containsPattern(CTE)) {
case cteDef @ CTERelationDef(_, _, Some(_)) =>
case cteDef @ CTERelationDef(_, _, Some(_), _) =>
cteDef.copy(originalPlanWithPredicates = None)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@ object ReplaceCTERefWithRepartition extends Rule[LogicalPlan] {
case WithCTE(child, cteDefs) =>
cteDefs.foreach { cteDef =>
val inlined = replaceWithRepartition(cteDef.child, cteMap)
val withRepartition = if (inlined.isInstanceOf[RepartitionOperation]) {
// If the CTE definition plan itself is a repartition operation, we do not need to add an
// extra repartition shuffle.
inlined
} else {
Repartition(conf.numShufflePartitions, shuffle = true, inlined)
}
val withRepartition =
if (inlined.isInstanceOf[RepartitionOperation] || cteDef.underSubquery) {
// If the CTE definition plan itself is a repartition operation or if it hosts a merged
// scalar subquery, we do not need to add an extra repartition shuffle.
inlined
} else {
Repartition(conf.numShufflePartitions, shuffle = true, inlined)
}
cteMap.put(cteDef.id, withRepartition)
}
replaceWithRepartition(child, cteMap)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRe
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
Expand Down Expand Up @@ -663,11 +663,14 @@ case class UnresolvedWith(
* predicates that have been pushed down into `child`. This is
* a temporary field used by optimization rules for CTE predicate
* pushdown to help ensure rule idempotency.
* @param underSubquery If true, it means we don't need to add a shuffle for this CTE relation as
* subquery reuse will be applied to reuse CTE relation output.
*/
case class CTERelationDef(
child: LogicalPlan,
id: Long = CTERelationDef.newId,
originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None) extends UnaryNode {
originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None,
underSubquery: Boolean = false) extends UnaryNode {

final override val nodePatterns: Seq[TreePattern] = Seq(CTE)

Expand All @@ -678,17 +681,19 @@ case class CTERelationDef(
}

object CTERelationDef {
private val curId = new java.util.concurrent.atomic.AtomicLong()
private[sql] val curId = new java.util.concurrent.atomic.AtomicLong()
def newId: Long = curId.getAndIncrement()
}

/**
* Represents the relation of a CTE reference.
* @param cteId The ID of the corresponding CTE definition.
* @param _resolved Whether this reference is resolved.
* @param output The output attributes of this CTE reference, which can be different from
* the output of its corresponding CTE definition after attribute de-duplication.
* @param statsOpt The optional statistics inferred from the corresponding CTE definition.
* @param cteId The ID of the corresponding CTE definition.
* @param _resolved Whether this reference is resolved.
* @param output The output attributes of this CTE reference, which can be different
* from the output of its corresponding CTE definition after attribute
* de-duplication.
* @param statsOpt The optional statistics inferred from the corresponding CTE
* definition.
*/
case class CTERelationRef(
cteId: Long,
Expand Down Expand Up @@ -1014,6 +1019,24 @@ case class Aggregate(
}
}

object Aggregate {
def isAggregateBufferMutable(schema: StructType): Boolean = {
schema.forall(f => UnsafeRow.isMutable(f.dataType))
}

def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
isAggregateBufferMutable(aggregationBufferSchema)
}

def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
aggregateExpressions.map(_.aggregateFunction).exists {
case _: TypedImperativeAggregate[_] => true
case _ => false
}
}
}

case class Window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ object TreePattern extends Enumeration {
val REGEXP_REPLACE: Value = Value
val RUNTIME_REPLACEABLE: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SCALAR_SUBQUERY_REFERENCE: Value = Value
val SCALA_UDF: Value = Value
val SORT: Value = Value
val SUBQUERY_ALIAS: Value = Value
Expand Down
Loading

0 comments on commit e00b81e

Please sign in to comment.