
[SPARK-19981][SQL] Respect aliases in output partitioning of projects and aggregates #17400

Closed · wants to merge 3 commits
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.physical._

trait AliasAwareOutputPartitioning extends UnaryExecNode {
Member:

We might need a general utility class for this. cc @maryannxue, who did similar things for other projects in the past. Maybe @maryannxue can help deliver such a utility class?

Member (Author):

OK, I'll wait for @maryannxue's suggestion.


protected def outputExpressions: Seq[NamedExpression]

// If a project or an aggregate has aliases in its output expressions, we should respect
// those aliases when checking whether the operator satisfies its required output
// distribution. Otherwise, the EnsureRequirements rule wrongly adds shuffle operations, e.g.,
//
// spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df1")
// spark.range(10).selectExpr("id AS key", "0").repartition($"key").write.saveAsTable("df2")
// sql("""
// SELECT * FROM
// (SELECT key AS k from df1) t1
// INNER JOIN
// (SELECT key AS k from df2) t2
// ON t1.k = t2.k
// """).explain
//
// == Physical Plan ==
// *SortMergeJoin [k#56L], [k#57L], Inner
// :- *Sort [k#56L ASC NULLS FIRST], false, 0
// : +- Exchange hashpartitioning(k#56L, 200) // <--- Unnecessary shuffle operation
// : +- *Project [key#39L AS k#56L]
// : +- Exchange hashpartitioning(key#39L, 200)
// : +- *Project [id#36L AS key#39L]
// : +- *Range (0, 10, step=1, splits=Some(4))
// +- *Sort [k#57L ASC NULLS FIRST], false, 0
// +- ReusedExchange [k#57L], Exchange hashpartitioning(k#56L, 200)
final override def outputPartitioning: Partitioning = if (hasAlias(outputExpressions)) {
resolveOutputPartitioningByAliases(outputExpressions, child.outputPartitioning)
} else {
child.outputPartitioning
}

private def hasAlias(exprs: Seq[NamedExpression]): Boolean =
exprs.exists(_.collectFirst { case _: Alias => true }.isDefined)

private def resolveOutputPartitioningByAliases(
exprs: Seq[NamedExpression],
partitioning: Partitioning): Partitioning = {
val aliasSeq = exprs.flatMap(_.collectFirst {
case a @ Alias(child, _) => (child, a.toAttribute)
})
def mayReplaceExprWithAlias(e: Expression): Expression = {
aliasSeq.find { case (c, _) => c.semanticEquals(e) }.map(_._2).getOrElse(e)
}
def mayReplacePartitioningExprsWithAliases(p: Partitioning): Partitioning = p match {
case hash @ HashPartitioning(exprs, _) =>
hash.copy(expressions = exprs.map(mayReplaceExprWithAlias))
case range @ RangePartitioning(ordering, _) =>
range.copy(ordering = ordering.map { order =>
order.copy(
child = mayReplaceExprWithAlias(order.child),
sameOrderExpressions = order.sameOrderExpressions.map(mayReplaceExprWithAlias)
)
})
case _ => p
}

partitioning match {
case pc @ PartitioningCollection(ps) =>
pc.copy(partitionings = ps.map(mayReplacePartitioningExprsWithAliases))
case _ =>
mayReplacePartitioningExprsWithAliases(partitioning)
}
}
}
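
For reference, here is a minimal standalone sketch (not part of this PR's diff) of what the alias substitution above does, assuming catalyst's expression DSL is available in test scope; the names `id` and `key` and the partition count 200 are hypothetical:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

// The child is hash-partitioned by `id`, and the project emits `id AS key`.
val id = 'id.long
val key = Alias(id, "key")()
val childPartitioning = HashPartitioning(Seq(id), 200)

// Substituting the alias makes the operator report hashpartitioning(key, 200)
// instead of hashpartitioning(id, 200), so a downstream requirement on `key`
// is satisfied without inserting another Exchange.
val resolved = childPartitioning.copy(
  expressions = childPartitioning.expressions.map {
    case e if e.semanticEquals(id) => key.toAttribute
    case e => e
  })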
@@ -45,7 +45,9 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode with CodegenSupport {
extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {

override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,8 +68,6 @@

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def outputPartitioning: Partitioning = child.outputPartitioning

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
@@ -65,7 +65,9 @@ case class ObjectHashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode {
extends UnaryExecNode with AliasAwareOutputPartitioning {

override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -95,8 +97,6 @@
}
}

override def outputPartitioning: Partitioning = child.outputPartitioning

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
val aggTime = longMetric("aggTime")
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.Utils

@@ -38,7 +38,9 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryExecNode {
extends UnaryExecNode with AliasAwareOutputPartitioning {

override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -66,8 +68,6 @@ case class SortAggregateExec(
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}

override def outputPartitioning: Partitioning = child.outputPartitioning

override def outputOrdering: Seq[SortOrder] = {
groupingExpressions.map(SortOrder(_, Ascending))
}
@@ -33,10 +33,12 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/** Physical plan for Project. */
case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryExecNode with CodegenSupport {
extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {

override def output: Seq[Attribute] = projectList.map(_.toAttribute)

override protected def outputExpressions: Seq[NamedExpression] = projectList

override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
}
@@ -76,8 +78,6 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputPartitioning: Partitioning = child.outputPartitioning
}


@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

6 changes: 6 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -73,3 +73,9 @@ where b.z != b.z;
-- SPARK-24369 multiple distinct aggregations having the same argument set
SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y);

-- SPARK-19981 Correctly resolve partitioning when output has aliases
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 'a'), (1, 'b') AS (a, b) DISTRIBUTE BY a;
EXPLAIN SELECT k, MAX(b) FROM (SELECT a AS k, b FROM t1) t GROUP BY k;
CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 2), (0, 3) AS (a, b);
EXPLAIN SELECT k, COUNT(v) FROM (SELECT a AS k, MAX(b) AS v FROM t2 GROUP BY a) t GROUP BY k;
7 changes: 7 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/inner-join.sql
@@ -15,3 +15,10 @@ SELECT a, 'b' AS tag FROM t4;

-- SPARK-19766 Constant alias columns in INNER JOIN should not be folded by FoldablePropagation rule
SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag;

-- SPARK-19981 Correctly resolve partitioning when output has aliases
SET spark.sql.autoBroadcastJoinThreshold = -1;
CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES (1, 1), (3, 0) AS (k, v) DISTRIBUTE BY (k);
CREATE TEMPORARY VIEW t6 AS SELECT * FROM VALUES (1, 1), (5, 1) AS (k, v) DISTRIBUTE BY (k);
EXPLAIN SELECT * FROM (SELECT k AS k1 FROM t5) t5a
INNER JOIN (SELECT k AS k1 FROM t6) t6a ON t5a.k1 = t6a.k1;
46 changes: 45 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 27
-- Number of queries: 31


-- !query 0
@@ -250,3 +250,47 @@ SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*)
struct<corr(DISTINCT CAST(x AS DOUBLE), CAST(y AS DOUBLE)):double,corr(DISTINCT CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,count(1):bigint>
-- !query 26 output
1.0 1.0 3


-- !query 27
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1, 'a'), (1, 'b') AS (a, b) DISTRIBUTE BY a
-- !query 27 schema
struct<>
-- !query 27 output



-- !query 28
EXPLAIN SELECT k, MAX(b) FROM (SELECT a AS k, b FROM t1) t GROUP BY k
-- !query 28 schema
struct<plan:string>
-- !query 28 output
== Physical Plan ==
SortAggregate(key=[k#x], functions=[max(b#x)])
+- SortAggregate(key=[k#x], functions=[partial_max(b#x)])
+- *Sort [k#x ASC NULLS FIRST], false, 0
+- *Project [a#x AS k#x, b#x]
+- Exchange hashpartitioning(a#x, 200)
+- LocalTableScan [a#x, b#x]


-- !query 29
CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1, 2), (0, 3) AS (a, b)
-- !query 29 schema
struct<>
-- !query 29 output



-- !query 30
EXPLAIN SELECT k, COUNT(v) FROM (SELECT a AS k, MAX(b) AS v FROM t2 GROUP BY a) t GROUP BY k
-- !query 30 schema
struct<plan:string>
-- !query 30 output
== Physical Plan ==
*HashAggregate(keys=[k#x], functions=[count(v#x)])
+- *HashAggregate(keys=[k#x], functions=[partial_count(v#x)])
+- *HashAggregate(keys=[a#x], functions=[max(b#x)])
+- Exchange hashpartitioning(a#x, 200)
+- *HashAggregate(keys=[a#x], functions=[partial_max(b#x)])
+- LocalTableScan [a#x, b#x]
44 changes: 43 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/inner-join.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 7
-- Number of queries: 11


-- !query 0
@@ -65,3 +65,45 @@ struct<a:int,tag:string>
1 a
1 b
1 b


-- !query 7
SET spark.sql.autoBroadcastJoinThreshold = -1
-- !query 7 schema
struct<key:string,value:string>
-- !query 7 output
spark.sql.autoBroadcastJoinThreshold -1


-- !query 8
CREATE TEMPORARY VIEW t5 AS SELECT * FROM VALUES (1, 1), (3, 0) AS (k, v) DISTRIBUTE BY (k)
-- !query 8 schema
struct<>
-- !query 8 output



-- !query 9
CREATE TEMPORARY VIEW t6 AS SELECT * FROM VALUES (1, 1), (5, 1) AS (k, v) DISTRIBUTE BY (k)
-- !query 9 schema
struct<>
-- !query 9 output



-- !query 10
EXPLAIN SELECT * FROM (SELECT k AS k1 FROM t5) t5a
INNER JOIN (SELECT k AS k1 FROM t6) t6a ON t5a.k1 = t6a.k1
-- !query 10 schema
struct<plan:string>
-- !query 10 output
== Physical Plan ==
*SortMergeJoin [k1#x], [k1#x], Inner
:- *Sort [k1#x ASC NULLS FIRST], false, 0
: +- *Project [k#x AS k1#x]
: +- Exchange hashpartitioning(k#x, 200)
: +- LocalTableScan [k#x]
+- *Sort [k1#x ASC NULLS FIRST], false, 0
+- *Project [k#x AS k1#x]
+- Exchange hashpartitioning(k#x, 200)
+- LocalTableScan [k#x]