Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-40382][SQL] Group distinct aggregate expressions by semantically equivalent children in RewriteDistinctAggregates #37825

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
if (unfoldableChildren.nonEmpty) {
// Only expand the unfoldable children
unfoldableChildren
Expand All @@ -231,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// count(distinct 1) will be explained to count(1) after the rewrite function.
// Generally, the distinct aggregateFunction should not run
// foldable TypeCheck for the first child.
e.aggregateFunction.children.take(1).toSet
ExpressionSet(e.aggregateFunction.children.take(1))
}
}

Expand All @@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this is necessary, but it's better to use ExpressionSet(distinctAggGroups.keySet.flatten).toSeq, instead of calling .distinct on Seq[Expression]

val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
val distinctAggChildAttrMap = distinctAggChildren.map { e =>
e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we can update expressionAttributePair.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expressionAttributePair is used in two other places, though, for regular aggregate children and filter expressions where the key does not need to be canonicalized.

}
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
// Setup all the filters in distinct aggregate.
val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect {
Expand Down Expand Up @@ -292,7 +294,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
af
} else {
patchAggregateFunctionChildren(af) { x =>
distinctAggChildAttrLookup.get(x)
distinctAggChildAttrLookup.get(x.canonicalized)
}
}
val newCondition = if (condition.isDefined) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,37 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
.analyze
checkRewrite(RewriteDistinctAggregates(input))
}

test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") {
  // Two count(DISTINCT ...) aggregates whose children differ only cosmetically
  // (b + c vs. c + b) plus one non-distinct aggregate. After grouping by semantic
  // equivalence there is a single distinct group, so the rewrite must not insert
  // an Expand: the plan should bottom out at the LocalRelation directly.
  val plan = testRelation
    .groupBy($"a")(
      countDistinct($"b" + $"c").as("agg1"),
      countDistinct($"c" + $"b").as("agg2"),
      max($"c").as("agg3"))
    .analyze

  val rewritten = RewriteDistinctAggregates(plan)
  val hasExpectedShape = rewritten match {
    case Aggregate(_, _, LocalRelation(_, _, _)) => true
    case _ => false
  }
  if (!hasExpectedShape) {
    fail(s"Plan is not as expected:\n$rewritten")
  }
}

test("SPARK-40382: reduce multiple distinct groups due to superficial differences") {
  // Four count(DISTINCT ...) aggregates that fold into two semantically distinct
  // groups (b + c + d vs. d + c + b, and b + c vs. c + b). The rewrite still needs
  // an Expand, but with only 3 projections: one per distinct group plus one for
  // the regular (non-distinct) aggregates.
  val plan = testRelation
    .groupBy($"a")(
      countDistinct($"b" + $"c" + $"d").as("agg1"),
      countDistinct($"d" + $"c" + $"b").as("agg2"),
      countDistinct($"b" + $"c").as("agg3"),
      countDistinct($"c" + $"b").as("agg4"),
      max($"c").as("agg5"))
    .analyze

  RewriteDistinctAggregates(plan) match {
    case Aggregate(_, _, Aggregate(_, _, expand: Expand)) =>
      assert(expand.projections.size == 3)
    case other =>
      fail(s"Plan is not rewritten:\n$other")
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
if (functionsWithDistinct.map(
_.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) {
val distinctAggChildSets = functionsWithDistinct.map { ae =>
ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable))
}.distinct
if (distinctAggChildSets.length > 1) {
// This is a sanity check. We should not reach here when we have multiple distinct
// column sets. Our `RewriteDistinctAggregates` should take care of this case.
throw new IllegalStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,17 @@ object AggUtils {
}

// 3. Create an Aggregate operator for partial aggregation (for distinct)
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
distinctAttributes)
val rewrittenDistinctFunctions = functionsWithDistinct.map {
// Children of an AggregateFunction with the DISTINCT keyword have already
// been evaluated. Here, we need to replace the original children
// with AttributeReferences.
case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
aggregateFunction.transformDown {
case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
distinctColumnAttributeLookup(e.canonicalized)
}.asInstanceOf[AggregateFunction]
case agg =>
throw new IllegalArgumentException(
"Non-distinct aggregate is found in functionsWithDistinct " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,40 @@ class DataFrameAggregateSuite extends QueryTest
val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
checkAnswer(df, Row(2, 3, 1))
}

test("SPARK-40382: Distinct aggregation expression grouping by semantic equivalence") {
  // Small fixture exercised through SQL so that semantically-equal distinct
  // children written with cosmetic differences (c1 + 1 vs. 1 + c1, case of C1)
  // are grouped together end-to-end.
  val rows = Seq(
    (1, 1, 3),
    (1, 2, 3),
    (1, 2, 3),
    (2, 1, 1),
    (2, 2, 5))
  rows.toDF("k", "c1", "c2").createOrReplaceTempView("df")

  // all distinct aggregation children are semantically equivalent
  val allEquivalent = sql(
    """select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
      |from df
      |group by k
      |""".stripMargin)
  checkAnswer(allEquivalent, Seq(Row(1, 5, 2.5, 2), Row(2, 5, 2.5, 2)))

  // some distinct aggregation children are semantically equivalent
  val someEquivalent = sql(
    """select k, sum(distinct c1 + 2), avg(distinct 2 + c1), count(distinct c2)
      |from df
      |group by k
      |""".stripMargin)
  checkAnswer(someEquivalent, Seq(Row(1, 7, 3.5, 1), Row(2, 7, 3.5, 2)))

  // no distinct aggregation children are semantically equivalent
  val noneEquivalent = sql(
    """select k, sum(distinct c1 + 2), avg(distinct 3 + c1), count(distinct c2)
      |from df
      |group by k
      |""".stripMargin)
  checkAnswer(noneEquivalent, Seq(Row(1, 7, 4.5, 1), Row(2, 7, 4.5, 2)))
}
}

case class B(c: Option[Double])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
// 2 distinct columns with different order
val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
assertNoExpand(query3.queryExecution.executedPlan)

// SPARK-40382: 1 distinct expression with cosmetic differences
val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
assertNoExpand(query4.queryExecution.executedPlan)
}
}

Expand Down