Skip to content

Commit

Permalink
Revert "[SPARK-34581][SQL] Don't optimize out grouping expressions fr…
Browse files Browse the repository at this point in the history
…om aggregate expressions without aggregate function"

This reverts commit c8d78a7.
  • Loading branch information
cloud-fan committed Apr 23, 2021
1 parent 20d68dc commit fdccd88
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

/**
Expand Down Expand Up @@ -52,22 +52,3 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] {
}
}
}

/**
 * Keeps the `nullable` flag of every [[GroupingExprRef]] inside an [[Aggregate]]'s aggregate
 * expressions in sync with the nullability of the grouping expression found at the
 * reference's ordinal. Intended to run on a resolved LogicalPlan.
 */
object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case agg: Aggregate =>
      // Nullability of each grouping expression, indexed by its position in the grouping list.
      val groupingNullable = agg.groupingExpressions.map(_.nullable).toArray

      // Rewrites only the references whose flag disagrees with the referenced grouping
      // expression; every other node of the expression tree is left as is.
      def refreshed(expr: NamedExpression): NamedExpression = {
        val updated = expr.transform {
          case ref: GroupingExprRef if ref.nullable != groupingNullable(ref.ordinal) =>
            ref.copy(nullable = groupingNullable(ref.ordinal))
        }
        updated.asInstanceOf[NamedExpression]
      }

      agg.copy(aggregateExpressions = agg.aggregateExpressions.map(refreshed))
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ trait AliasHelper {
protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
val aliasMap = plan.aggregateExpressions.collect {
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
(a.toAttribute, a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,6 @@ object AggregateExpression {
filter,
NamedExpression.newExprId)
}

/**
 * Returns true when `expr`, or any node reachable through `expr.find`, satisfies
 * [[isAggregate]].
 */
def containsAggregate(expr: Expression): Boolean =
  expr.find(isAggregate).nonEmpty

/**
 * Returns true when `expr` itself is an [[AggregateExpression]] or is accepted by
 * `PythonUDF.isGroupedAggPandasUDF`. Only the given node is inspected, not its children.
 */
def isAggregate(expr: Expression): Boolean = expr match {
  case _: AggregateExpression => true
  case other => PythonUDF.isGroupedAggPandasUDF(other)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,22 +277,3 @@ object GroupingAnalytics {
}
}
}

/**
 * A reference to a grouping expression in an [[Aggregate]] node.
 *
 * Only the ordinal is exposed through `stringArgs`; the data type and nullability are
 * carried by the node but left out of that iterator.
 *
 * @param ordinal The position, within the [[Aggregate]]'s grouping expressions, of the
 *                expression this node refers to.
 * @param dataType The [[DataType]] of the referenced grouping expression.
 * @param nullable True if null is a valid value for the referenced grouping expression.
 */
case class GroupingExprRef(ordinal: Int, dataType: DataType, nullable: Boolean)
  extends LeafExpression with Unevaluable {

  override def stringArgs: Iterator[Any] = Iterator(ordinal)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,23 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

/**
* Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
*/
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// One place where this optimization is invalid is an aggregation where the select
// list expression is a function of a grouping expression:
//
// SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
//
// cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
// optimization for Aggregates (although this misses some cases where the optimization
// can be made).
case a: Aggregate => a
case p => p.transformExpressionsUp {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
OptimizeUpdateFields,
SimplifyExtractValueOps,
OptimizeCsvJsonExprs,
CombineConcats,
UpdateGroupingExprRefNullability) ++
CombineConcats) ++
extendedOperatorOptimizationRules

val operatorOptimizationBatch: Seq[Batch] = {
Expand Down Expand Up @@ -149,7 +148,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView,
ReplaceExpressions,
RewriteNonCorrelatedExists,
EnforceGroupingReferencesInAggregates,
ComputeCurrentTime,
GetCurrentDatabaseAndCatalog(catalogManager)) ::
//////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -269,9 +267,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceUpdateFieldsExpression.ruleName ::
EnforceGroupingReferencesInAggregates.ruleName ::
UpdateGroupingExprRefNullability.ruleName :: Nil
ReplaceUpdateFieldsExpression.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.
Expand Down Expand Up @@ -512,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)

val newAggregate = Aggregate.withGroupingRefs(
val newAggregate = upper.copy(
child = lower.child,
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
aggregateExpressions = upper.aggregateExpressions.map(
Expand All @@ -528,19 +524,23 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
}

private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
val upperHasNoAggregateExpressions =
!upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)
val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)

lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
lower
.aggregateExpressions
.filter(_.deterministic)
.filterNot(AggregateExpression.containsAggregate)
.filter(!isAggregate(_))
.map(_.toAttribute)
))

upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
}

/**
 * Returns true when `expr.find` locates at least one node that is an [[AggregateExpression]]
 * or is accepted by `PythonUDF.isGroupedAggPandasUDF`.
 */
private def isAggregate(expr: Expression): Boolean = {
  val firstAggregate = expr.find {
    case _: AggregateExpression => true
    case other => PythonUDF.isGroupedAggPandasUDF(other)
  }
  firstAggregate.nonEmpty
}
}

/**
Expand Down Expand Up @@ -1978,18 +1978,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
if (newGrouping.nonEmpty) {
val droppedGroupsBefore =
grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray

val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
}.asInstanceOf[NamedExpression])

a.copy(
groupingExpressions = newGrouping,
aggregateExpressions = newAggregateExpressions)
a.copy(groupingExpressions = newGrouping)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
Expand All @@ -2010,25 +1999,7 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
if (newGrouping.size == grouping.size) {
a
} else {
var i = 0
val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
i += 1
0
} else {
1
})
).toArray

val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
}.asInstanceOf[NamedExpression])

a.copy(
groupingExpressions = newGrouping,
aggregateExpressions = newAggregateExpressions)
a.copy(groupingExpressions = newGrouping)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,10 +633,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, _, child) =>
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
.map(extractCorrelatedScalarSubqueries(_, subqueries))
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ object PhysicalAggregation {
(Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)

def unapply(a: Any): Option[ReturnType] = a match {
case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of semantically distinct aggregate expressions and re-write expressions so
Expand All @@ -297,9 +297,11 @@ object PhysicalAggregation {
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
// addExpr() always returns false for non-deterministic expressions and does not add them.
case a
if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
a
case agg: AggregateExpression
if !equivalentAggregateExpressions.addExpr(agg) => agg
case udf: PythonUDF
if PythonUDF.isGroupedAggPandasUDF(udf) &&
!equivalentAggregateExpressions.addExpr(udf) => udf
}
}

Expand All @@ -320,7 +322,7 @@ object PhysicalAggregation {
// which takes the grouping columns and final aggregate result buffer as input.
// Thus, we must re-write the result expressions so that their attributes match up with
// the attributes of the final result projection's input row:
val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transformDown {
case ae: AggregateExpression =>
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,
Expand Down
Loading

0 comments on commit fdccd88

Please sign in to comment.