In BHJ, shuffle read should be local always (apache#53)
* In BHJ, shuffle read should be local

* add comments
Yucai Yu authored and carsonwang committed May 17, 2018
1 parent f2ad905 commit de21bb3
Showing 2 changed files with 48 additions and 43 deletions.
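The idea behind the change: once a SortMergeJoinExec has been replaced by a BroadcastHashJoinExec at runtime, the shuffle output of the stream side no longer has to be redistributed across the cluster, so each task can read the shuffle files written on its own node instead of fetching blocks remotely. Previously the marking happened inside optimizeForLocalShuffleReadLessPartitions and was only reverted for inner joins when the optimization was abandoned; after this commit it happens right where the join is converted and is reverted for every join type. A minimal sketch of the marking step, assuming the ShuffleQueryStageInput type and its mutable isLocalShuffle flag from this adaptive-execution branch:

```scala
import org.apache.spark.sql.execution.SparkPlan

// Sketch only: markLocalShuffleReads is a hypothetical name; in the diff
// below this logic is inlined in OptimizeJoin right after the
// BroadcastHashJoinExec is built. ShuffleQueryStageInput comes from this
// adaptive-execution branch, not stock Spark.
def markLocalShuffleReads(broadcastJoin: SparkPlan): Unit = {
  broadcastJoin.children.foreach {
    // Flip every shuffle input under the new join to a local read.
    case input: ShuffleQueryStageInput => input.isLocalShuffle = true
    case _ => // leave non-shuffle children (e.g. the broadcast side) alone
  }
}
```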
@@ -81,12 +81,6 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
   private def optimizeForLocalShuffleReadLessPartitions(
       broadcastSidePlan: SparkPlan,
       childrenPlans: Seq[SparkPlan]) = {
-    // All shuffle read should be local instead of remote
-    childrenPlans.foreach {
-      case input: ShuffleQueryStageInput =>
-        input.isLocalShuffle = true
-      case _ =>
-    }
     // If there's shuffle write on broadcast side, then find the partitions with 0 size and ignore
     // reading them in local shuffle read.
     broadcastSidePlan match {
@@ -138,6 +132,12 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
         condition,
         removeSort(left),
         removeSort(right))
+      // All shuffle read should be local instead of remote
+      broadcastJoin.children.foreach {
+        case input: ShuffleQueryStageInput =>
+          input.isLocalShuffle = true
+        case _ =>
+      }

       val newChild = queryStage.child.transformDown {
         case s: SortMergeJoinExec if s.fastEquals(smj) => broadcastJoin
@@ -177,11 +177,7 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
       } else {
         logWarning("Join optimization is not applied due to additional shuffles will be " +
           "introduced. Enable spark.sql.adaptive.allowAdditionalShuffle to allow it.")
-        joinType match {
-          case _: InnerLike =>
-            revertShuffleReadChanges(broadcastJoin.children)
-          case _ =>
-        }
+        revertShuffleReadChanges(broadcastJoin.children)
         smj
       }
     }.getOrElse(smj)
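revertShuffleReadChanges is called above but not shown in this diff. A plausible minimal shape, assuming it simply undoes the marking when the rewrite is abandoned, would be:

```scala
// Hypothetical reconstruction; the actual helper is defined elsewhere in
// OptimizeJoin and may differ. Assumed behavior: clear the local-read flag
// on every shuffle input when the broadcast-join rewrite is given up.
private def revertShuffleReadChanges(childrenPlans: Seq[SparkPlan]): Unit = {
  childrenPlans.foreach {
    case input: ShuffleQueryStageInput =>
      input.isLocalShuffle = false
    case _ =>
  }
}
```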
@@ -72,6 +72,42 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }

+  def checkJoin(join: DataFrame, spark: SparkSession): Unit = {
+    // Before execution, there is one SortMergeJoin
+    val smjBeforeExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjBeforeExecution.length === 1)
+
+    // Check the answer.
+    val expectedAnswer =
+      spark
+        .range(0, 1000)
+        .selectExpr("id % 500 as key", "id as value")
+        .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+    checkAnswer(
+      join,
+      expectedAnswer.collect())
+
+    // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
+    val smjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjAfterExecution.length === 0)
+
+    val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: BroadcastHashJoinExec => smj
+    }.length
+    assert(numBhjAfterExecution === 1)
+
+    // Both shuffle reads should be local
+    val queryStageInputs = join.queryExecution.executedPlan.collect {
+      case q: ShuffleQueryStageInput => q
+    }
+    assert(queryStageInputs.length === 2)
+    assert(queryStageInputs.forall(_.isLocalShuffle) === true)
+  }
+
   test("1 sort merge join to broadcast join") {
     withSparkSession(defaultSparkSession) { spark: SparkSession =>
       val df1 =
@@ -83,39 +119,12 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
         .range(0, 1000, 1, numInputPartitions)
         .selectExpr("id % 500 as key2", "id as value2")

-      val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
-
-      // Before Execution, there is one SortMergeJoin
-      val smjBeforeExecution = join.queryExecution.executedPlan.collect {
-        case smj: SortMergeJoinExec => smj
-      }
-      assert(smjBeforeExecution.length === 1)
-
-      // Check the answer.
-      val expectedAnswer =
-        spark
-          .range(0, 1000)
-          .selectExpr("id % 500 as key", "id as value")
-          .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
-      checkAnswer(
-        join,
-        expectedAnswer.collect())
-
-      // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
-      val smjAfterExecution = join.queryExecution.executedPlan.collect {
-        case smj: SortMergeJoinExec => smj
-      }
-      assert(smjAfterExecution.length === 0)
-
-      val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
-        case smj: BroadcastHashJoinExec => smj
-      }.length
-      assert(numBhjAfterExecution === 1)
-
-      val queryStageInputs = join.queryExecution.executedPlan.collect {
-        case q: QueryStageInput => q
-      }
-      assert(queryStageInputs.length === 2)
+      val innerJoin = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
+      checkJoin(innerJoin, spark)
+
+      val leftJoin =
+        df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1"))
+      checkJoin(leftJoin, spark)
     }
   }

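For completeness, a snippet like the following would exercise both code paths outside the suite. It is illustrative only: apart from spark.sql.adaptive.allowAdditionalShuffle, which the warning message in the rule quotes, the configuration key below is a stock Spark setting assumed to carry over to this branch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Assumed setup: the fork's runtime re-optimization is gated on the stock
// spark.sql.adaptive.enabled flag, as in the suite's defaultSparkSession.
val spark = SparkSession.builder()
  .master("local[4]")
  .appName("bhj-local-shuffle-read")
  .config("spark.sql.adaptive.enabled", "true")
  .getOrCreate()

val df1 = spark.range(0, 1000, 1, 4).selectExpr("id % 500 as key1", "id as value1")
val df2 = spark.range(0, 1000, 1, 4).selectExpr("id % 500 as key2", "id as value2")

// Both joins should now convert to BroadcastHashJoinExec at runtime, with
// every ShuffleQueryStageInput in the final plan marked isLocalShuffle.
df1.join(df2, col("key1") === col("key2")).collect()
df1.join(df2, col("key1") === col("key2"), "left").collect()
```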
