address PR comments
imback82 committed Jan 23, 2020
1 parent 323d4a7 commit b877de7
Showing 3 changed files with 49 additions and 17 deletions.
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution

-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}

 /**
@@ -27,14 +27,21 @@ trait AliasAwareOutputPartitioning extends UnaryExecNode {
   protected def outputExpressions: Seq[NamedExpression]

   final override def outputPartitioning: Partitioning = {
-    child.outputPartitioning match {
-      case HashPartitioning(expressions, numPartitions) =>
-        val newExpressions = expressions.map {
-          case a: AttributeReference =>
-            replaceAlias(a).getOrElse(a)
-          case other => other
-        }
-        HashPartitioning(newExpressions, numPartitions)
-      case other => other
-    }
+    if (hasAlias) {
+      child.outputPartitioning match {
+        case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
+        case other => other
+      }
+    } else {
+      child.outputPartitioning
+    }
   }
+
+  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+
+  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map {
+      case a: AttributeReference => replaceAlias(a).getOrElse(a)
+      case other => other
+    }
+  }
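In effect, `outputPartitioning` now returns the child's partitioning untouched when the projection contains no aliases, and otherwise rewrites only the expressions inside a `HashPartitioning`. A minimal spark-shell style sketch of what this buys at the query level (the DataFrames mirror the planner tests below; the exact `explain` output shape is indicative only):

```scala
import org.apache.spark.sql.functions.{col, count, lit}

// Match the test setup: force a sort-merge join instead of a broadcast join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

// Both aggregates hash-partition their output on the grouping key.
val agg1 = spark.range(10).selectExpr("floor(id/4) as k1")
  .groupBy("k1").agg(count(lit("1")).as("cnt1"))

// Renaming k2 to k3 wraps the key in an Alias. With this change the
// aggregate's HashPartitioning(k2) is rewritten to HashPartitioning(k3)
// rather than being discarded.
val agg2 = spark.range(10).selectExpr("floor(id/4) as k2")
  .groupBy("k2").agg(count(lit("1")).as("cnt2"))
  .withColumnRenamed("k2", "k3")

// The join can now reuse the two aggregate shuffles and adds none of its
// own: expect exactly two Exchange nodes in the plan.
agg1.join(agg2, col("k1") === col("k3")).explain()
```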
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -990,11 +991,40 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {

       val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
       val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")

       val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan

+      assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)
+
       val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
       assert(exchanges.size == 2)
     }
   }

+  test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(true, false).foreach { useObjectHashAgg =>
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
+          val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+          val t2 = spark.range(10).selectExpr("floor(id/4) as k2")
+
+          val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
+          val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")
+
+          val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+          if (useObjectHashAgg) {
+            assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
+          } else {
+            assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+          }
+
+          val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+          assert(exchanges.size == 2)
+        }
+      }
+    }
+  }
 }

// Used for unit-testing EnsureRequirements
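Between them, the two planner tests above cover all three aggregate physical operators that mix in `AliasAwareOutputPartitioning`: `HashAggregateExec`, `ObjectHashAggregateExec`, and `SortAggregateExec`. A sketch of why toggling `USE_OBJECT_HASH_AGG` flips the operator (not part of the commit; illustrative only):

```scala
import org.apache.spark.sql.functions.collect_list

// count(...) has a fixed-width aggregation buffer and compiles to
// HashAggregateExec. collect_list is a TypedImperativeAggregate, so the
// planner picks ObjectHashAggregateExec when
// spark.sql.execution.useObjectHashAggExec is true and falls back to
// SortAggregateExec when it is false; hence the Seq(true, false).foreach
// in the new test.
spark.conf.set("spark.sql.execution.useObjectHashAggExec", "false")
spark.range(10).selectExpr("floor(id/4) as k")
  .groupBy("k").agg(collect_list("k"))
  .explain()  // expect SortAggregate rather than ObjectHashAggregate nodes
```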
@@ -608,16 +608,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
       withTable("t") {
         withView("v") {
-          val df = (0 until 20).map(i => (i, i)).toDF("i", "j").as("df")
-          df.write.format("parquet").bucketBy(8, "i").saveAsTable("t")
-
+          spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
           sql("CREATE VIEW v AS SELECT * FROM t").collect()

-          val plan1 = sql("SELECT * FROM t a JOIN t b ON a.i = b.i").queryExecution.executedPlan
-          assert(plan1.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
-
-          val plan2 = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
-          assert(plan2.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
+          val plan = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
+          assert(plan.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
         }
       }
     }
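For context on the simplified bucketing test: scanning a table bucketed by `i` reports `HashPartitioning` on `i`, and the view wraps that scan in a projection that aliases its output columns, which is the aliasing case this change targets. A rough way to reproduce outside the suite (a sketch; assumes a session where `saveAsTable` and `CREATE VIEW` are available):

```scala
// As in the test, force shuffle-based joins so a missing partitioning
// would show up as an Exchange node.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "0")

spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
spark.sql("CREATE VIEW v AS SELECT * FROM t")

// With alias-aware output partitioning the bucket info survives the view's
// projection: expect no Exchange in this plan.
spark.sql("SELECT * FROM t a JOIN v b ON a.i = b.i").explain()
```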
