diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala
index 3717f4acc757e..65b00653d42d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala
@@ -81,12 +81,6 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
   private def optimizeForLocalShuffleReadLessPartitions(
       broadcastSidePlan: SparkPlan,
       childrenPlans: Seq[SparkPlan]) = {
-    // All shuffle read should be local instead of remote
-    childrenPlans.foreach {
-      case input: ShuffleQueryStageInput =>
-        input.isLocalShuffle = true
-      case _ =>
-    }
     // If there's shuffle write on broadcast side, then find the partitions with 0 size and ignore
     // reading them in local shuffle read.
     broadcastSidePlan match {
@@ -138,6 +132,12 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
         condition,
         removeSort(left),
         removeSort(right))
+      // All shuffle read should be local instead of remote
+      broadcastJoin.children.foreach {
+        case input: ShuffleQueryStageInput =>
+          input.isLocalShuffle = true
+        case _ =>
+      }
       val newChild = queryStage.child.transformDown {
         case s: SortMergeJoinExec if s.fastEquals(smj) => broadcastJoin
       }
@@ -177,11 +177,7 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
       } else {
         logWarning("Join optimization is not applied due to additional shuffles will be " +
           "introduced. Enable spark.sql.adaptive.allowAdditionalShuffle to allow it.")
-        joinType match {
-          case _: InnerLike =>
-            revertShuffleReadChanges(broadcastJoin.children)
-          case _ =>
-        }
+        revertShuffleReadChanges(broadcastJoin.children)
         smj
       }
     }.getOrElse(smj)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala
index a25acda5f971c..5295bd77a76fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala
@@ -72,6 +72,42 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }
 
+  def checkJoin(join: DataFrame, spark: SparkSession): Unit = {
+    // Before Execution, there is one SortMergeJoin
+    val smjBeforeExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjBeforeExecution.length === 1)
+
+    // Check the answer.
+    val expectedAnswer =
+      spark
+        .range(0, 1000)
+        .selectExpr("id % 500 as key", "id as value")
+        .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+    checkAnswer(
+      join,
+      expectedAnswer.collect())
+
+    // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
+    val smjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjAfterExecution.length === 0)
+
+    val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: BroadcastHashJoinExec => smj
+    }.length
+    assert(numBhjAfterExecution === 1)
+
+    // Both shuffle should be local shuffle
+    val queryStageInputs = join.queryExecution.executedPlan.collect {
+      case q: ShuffleQueryStageInput => q
+    }
+    assert(queryStageInputs.length === 2)
+    assert(queryStageInputs.forall(_.isLocalShuffle) === true)
+  }
+
   test("1 sort merge join to broadcast join") {
     withSparkSession(defaultSparkSession) { spark: SparkSession =>
       val df1 =
@@ -83,39 +119,12 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
           .range(0, 1000, 1, numInputPartitions)
          .selectExpr("id % 500 as key2", "id as value2")
 
-      val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
-
-      // Before Execution, there is one SortMergeJoin
-      val smjBeforeExecution = join.queryExecution.executedPlan.collect {
-        case smj: SortMergeJoinExec => smj
-      }
-      assert(smjBeforeExecution.length === 1)
-
-      // Check the answer.
-      val expectedAnswer =
-        spark
-          .range(0, 1000)
-          .selectExpr("id % 500 as key", "id as value")
-          .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
-      checkAnswer(
-        join,
-        expectedAnswer.collect())
-
-      // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
-      val smjAfterExecution = join.queryExecution.executedPlan.collect {
-        case smj: SortMergeJoinExec => smj
-      }
-      assert(smjAfterExecution.length === 0)
+      val innerJoin = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
+      checkJoin(innerJoin, spark)
 
-      val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
-        case smj: BroadcastHashJoinExec => smj
-      }.length
-      assert(numBhjAfterExecution === 1)
-
-      val queryStageInputs = join.queryExecution.executedPlan.collect {
-        case q: QueryStageInput => q
-      }
-      assert(queryStageInputs.length === 2)
+      val leftJoin =
+        df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1"))
+      checkJoin(leftJoin, spark)
     }
   }
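
For reference, a minimal standalone sketch of the two behavioral changes in this patch: the local-shuffle flag is now flipped on the children of the BroadcastHashJoinExec that is actually built (rather than eagerly on the SortMergeJoin's children), and the revert path now runs for every join type instead of only inner joins. The types below (ShuffleInput, BroadcastHashJoin) are simplified stand-ins for Spark's ShuffleQueryStageInput and BroadcastHashJoinExec, not the real plan nodes.

// Simplified stand-ins for ShuffleQueryStageInput and BroadcastHashJoinExec;
// only the mutable isLocalShuffle flag matters for this sketch.
sealed trait Plan { def children: Seq[Plan] }
case class ShuffleInput(var isLocalShuffle: Boolean = false) extends Plan {
  override def children: Seq[Plan] = Nil
}
case class BroadcastHashJoin(children: Seq[Plan]) extends Plan

object LocalShuffleSketch {
  // After this patch: set the flag only on the built join's children,
  // so nothing is mutated before the broadcast join exists.
  def markLocal(join: BroadcastHashJoin): Unit = join.children.foreach {
    case input: ShuffleInput => input.isLocalShuffle = true
    case _ =>
  }

  // After this patch: revert runs unconditionally when the optimization
  // is abandoned, with no joinType match guarding it.
  def revertShuffleReadChanges(children: Seq[Plan]): Unit = children.foreach {
    case input: ShuffleInput => input.isLocalShuffle = false
    case _ =>
  }

  def main(args: Array[String]): Unit = {
    val join = BroadcastHashJoin(Seq(ShuffleInput(), ShuffleInput()))
    markLocal(join)
    assert(join.children.collect { case s: ShuffleInput => s }.forall(_.isLocalShuffle))
    revertShuffleReadChanges(join.children)
    assert(join.children.collect { case s: ShuffleInput => s }.forall(!_.isLocalShuffle))
  }
}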