Skip to content

Commit

Permalink
[SPARK-39915][SQL] Dataset.repartition(N) may not create N partitions…
Browse files Browse the repository at this point in the history
… Non-AQE part

Skip optimizing the root user-specified repartition in `PropagateEmptyRelation`.

Spark should preserve the final repartition, because it is user-specified and determines the final output partitioning.

For example:

```scala
spark.sql("select * from values(1) where 1 < rand()").repartition(1)

// before:
== Optimized Logical Plan ==
LocalTableScan <empty>, [col1#0]

// after:
== Optimized Logical Plan ==
Repartition 1, true
+- LocalRelation <empty>, [col1#0]
```

yes, the empty plan may change

add test

Closes apache#37706 from ulysses-you/empty.

Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
ulysses-you committed Aug 30, 2022
1 parent e46d2e2 commit af2ca71
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,9 @@ package object dsl {
/** Wraps this plan in a [[Repartition]] node with `num` partitions and shuffling enabled. */
def repartition(num: Integer): LogicalPlan = {
  val shuffled = Repartition(num, shuffle = true, logicalPlan)
  shuffled
}

/**
 * Wraps this plan in a [[RepartitionByExpression]] node that has neither partition expressions
 * nor an explicit number of partitions.
 */
def repartition(): LogicalPlan = {
  val noExpressions = Seq.empty
  RepartitionByExpression(noExpressions, logicalPlan, None)
}

/**
 * Wraps this plan in a [[RepartitionByExpression]] node using `exprs` as the partition
 * expressions and `n` as the number of partitions.
 */
def distribute(exprs: Expression*)(n: Int): LogicalPlan = {
  RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}

/**
Expand All @@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
* - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
*/
abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
// Tag that marks a repartition node as the root, user-specified repartition of the plan.
// `addTagForRootRepartition` sets it before the rule runs; when the rule meets a tagged
// repartition it clears the tag and keeps the node instead of replacing it with an empty relation.
private[sql] val ROOT_REPARTITION = TreeNodeTag[Unit]("ROOT_REPARTITION")

protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
case p: LocalRelation => p.data.isEmpty
case _ => false
Expand Down Expand Up @@ -136,8 +140,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
case _: Sort => empty(p)
case _: GlobalLimit if !p.isStreaming => empty(p)
case _: LocalLimit if !p.isStreaming => empty(p)
case _: Repartition => empty(p)
case _: RepartitionByExpression => empty(p)
case _: RepartitionOperation =>
if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
empty(p)
} else {
p.unsetTagValue(ROOT_REPARTITION)
p
}
case _: RebalancePartitions => empty(p)
// An aggregate with non-empty group expression will return one output row per group when the
// input to the aggregate is not empty. If the input to the aggregate is empty then all groups
Expand All @@ -160,13 +169,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
case _ => p
}
}

/**
 * Returns true when `p` is a repartition that was explicitly configured by the user: either a
 * [[Repartition]], or a [[RepartitionByExpression]] that carries a partition number or at least
 * one partition expression.
 */
protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
  case _: Repartition => true
  case byExpr: RepartitionByExpression =>
    byExpr.optNumPartitions.isDefined || byExpr.partitionExpressions.nonEmpty
  case _ => false
}

/**
 * The actual empty-relation propagation, supplied by each concrete rule. It is invoked by
 * `apply` with the plan whose root user-specified repartition has already been tagged.
 */
protected def applyInternal(plan: LogicalPlan): LogicalPlan

/**
 * Marks the root user-specified repartition with the [[ROOT_REPARTITION]] tag so that this rule
 * leaves it in place. The search descends only through [[Project]] and [[Filter]] nodes and
 * stops at the first node of any other kind.
 */
private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
  case _: Project | _: Filter =>
    plan.mapChildren(addTagForRootRepartition)
  case repartition if userSpecifiedRepartition(repartition) =>
    // The tag is attached to the node in place; the node itself is returned unchanged.
    repartition.setTagValue(ROOT_REPARTITION, ())
    repartition
  case other => other
}

/**
 * Entry point of the rule: tags the root user-specified repartition first, then hands the plan
 * to the concrete rule's `applyInternal`.
 */
override def apply(plan: LogicalPlan): LogicalPlan =
  applyInternal(addTagForRootRepartition(plan))
}

/**
 * This rule runs in the normal optimizer
 */
object PropagateEmptyRelation extends PropagateEmptyRelationBase {
  // Implement `applyInternal` rather than overriding `apply`: the base class `apply` tags the
  // root user-specified repartition before delegating here, and overriding `apply` directly
  // would bypass that tagging step.
  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
    _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
    commonApplyFunc
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,4 +309,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
val optimized2 = Optimize.execute(plan2)
comparePlans(optimized2, expected)
}

test("Propagate empty relation with repartition") {
  val emptyRelation = LocalRelation($"a".int, $"b".int)

  // None of these plans has a user-specified repartition at its root, so each one is expected
  // to collapse into the bare empty relation.
  val plans = Seq(
    emptyRelation.repartition(1).sortBy($"a".asc),
    emptyRelation.distribute($"a")(1).sortBy($"a".asc),
    emptyRelation.repartition(),
    emptyRelation.repartition(1).sortBy($"a".asc).repartition())

  plans.foreach { plan =>
    comparePlans(Optimize.execute(plan.analyze), emptyRelation.analyze)
  }
}

test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
  val emptyRelation = LocalRelation($"a".int, $"b".int)

  // A user-specified repartition at the root of the plan must survive optimization.
  val p1 = emptyRelation.repartition(1).analyze
  comparePlans(Optimize.execute(p1), p1)

  // The root repartition is still found when it sits below a Project ...
  val p2 = emptyRelation.repartition(1).select($"a").analyze
  comparePlans(Optimize.execute(p2), p2)

  // ... below a Filter ...
  val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
  comparePlans(Optimize.execute(p3), p3)

  // ... or below a Project on top of a Filter.
  val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
  comparePlans(Optimize.execute(p4), p4)

  // Only the root repartition is preserved; every operator below it is expected to collapse
  // into the empty relation. The sort key is the column `a` (`$"a"`), not the string "$a".
  val p5 = emptyRelation.sortBy($"a".asc).repartition().limit(1).repartition(1).analyze
  val expected5 = emptyRelation.repartition(1).analyze
  comparePlans(Optimize.execute(p5), expected5)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
empty(j)
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
// LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
// `PropagateEmptyRelationBase.commonApplyFunc`
// LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3281,6 +3281,13 @@ class DataFrameSuite extends QueryTest
Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
}

test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
  // This test covers the non-adaptive path, so adaptive execution is switched off.
  withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
    val repartitioned =
      spark.sql("select * from values(1) where 1 < rand()").repartition(2)
    val executedRdd = repartitioned.queryExecution.executedPlan.execute()
    assert(executedRdd.getNumPartitions == 2)
  }
}
}

case class GroupByKey(a: Int, b: Int)
Expand Down

0 comments on commit af2ca71

Please sign in to comment.