Skip to content

Commit

Permalink
Coalesce shuffle partition should handle empty input RDD
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you committed Apr 27, 2021
1 parent 2d2f467 commit 9e395ac
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REPARTITION, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.internal.SQLConf

Expand Down Expand Up @@ -54,16 +54,31 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
if (!shuffleStages.forall(s => supportCoalesce(s.shuffle))) {
plan
} else {
def insertCustomShuffleReader(partitionSpecs: Seq[ShufflePartitionSpec]): SparkPlan = {
// This transformation adds new nodes, so we must use `transformUp` here.
val stageIds = shuffleStages.map(_.id).toSet
plan.transformUp {
// even for shuffle exchange whose input RDD has 0 partition, we should still update its
// `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
// number of output partitions.
case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
CustomShuffleReaderExec(stage, partitionSpecs)
}
}

// `ShuffleQueryStageExec#mapStats` returns None when the input RDD has 0 partitions,
// we should skip it when calculating the `partitionStartIndices`.
// If all input RDDs have 0 partition, we create empty partition for every shuffle reader.
val validMetrics = shuffleStages.flatMap(_.mapStats)

// We may have different pre-shuffle partition numbers, don't reduce shuffle partition number
// in that case. For example when we union fully aggregated data (data is arranged to a single
// partition) and a result of a SortMergeJoin (multiple partitions).
val distinctNumPreShufflePartitions =
validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
if (validMetrics.isEmpty) {
insertCustomShuffleReader(ShufflePartitionsUtil.createEmptyPartition() :: Nil)
} else if (distinctNumPreShufflePartitions.length == 1) {
// We fall back to Spark default parallelism if the minimum number of coalesced partitions
// is not set, so to avoid perf regressions compared to no coalescing.
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
Expand All @@ -77,15 +92,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends CustomShuffl
if (partitionSpecs.length == distinctNumPreShufflePartitions.head) {
plan
} else {
// This transformation adds new nodes, so we must use `transformUp` here.
val stageIds = shuffleStages.map(_.id).toSet
plan.transformUp {
// even for shuffle exchange whose input RDD has 0 partition, we should still update its
// `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
// number of output partitions.
case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
CustomShuffleReaderExec(stage, partitionSpecs)
}
insertCustomShuffleReader(partitionSpecs)
}
} else {
plan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ object ShufflePartitionsUtil extends Logging {
partitionSpecs.toSeq
}

def createEmptyPartition(): ShufflePartitionSpec = {
CoalescedPartitionSpec(0, 0)
}

/**
* Given a list of size, return an array of indices to split the list into multiple partitions,
* so that the size sum of each partition is close to the target size. Each index indicates the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1575,4 +1575,19 @@ class AdaptiveQueryExecSuite
checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS)
}
}

test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") {
withTable("t") {
withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
spark.sql("CREATE TABLE t (c1 int) USING PARQUET")
val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1")
assert(
collect(adaptive) {
case c @ CustomShuffleReaderExec(_, partitionSpecs) if partitionSpecs.length == 1 => c
}.length == 1
)
}
}
}
}

0 comments on commit 9e395ac

Please sign in to comment.