diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index d47e34b110dc8..000622187f406 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -491,6 +491,9 @@ package object dsl {
     def repartition(num: Integer): LogicalPlan =
       Repartition(num, shuffle = true, logicalPlan)
 
+    def repartition(): LogicalPlan =
+      RepartitionByExpression(Seq.empty, logicalPlan, None)
+
     def distribute(exprs: Expression*)(n: Int): LogicalPlan =
       RepartitionByExpression(exprs, logicalPlan, numPartitions = n)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 2c964fa6da3db..f8e2096e44326 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_FALSE_LITERAL}
 
 /**
@@ -44,6 +45,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TRUE_OR_
  * - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
  */
 abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSupport {
+  // This tag is used to mark a repartition as a root repartition that is user-specified
+  private[sql] val ROOT_REPARTITION = TreeNodeTag[Unit]("ROOT_REPARTITION")
+
   protected def isEmpty(plan: LogicalPlan): Boolean = plan match {
     case p: LocalRelation => p.data.isEmpty
     case _ => false
@@ -136,8 +140,13 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
       case _: Sort => empty(p)
       case _: GlobalLimit if !p.isStreaming => empty(p)
       case _: LocalLimit if !p.isStreaming => empty(p)
-      case _: Repartition => empty(p)
-      case _: RepartitionByExpression => empty(p)
+      case _: RepartitionOperation =>
+        if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
+          empty(p)
+        } else {
+          p.unsetTagValue(ROOT_REPARTITION)
+          p
+        }
       case _: RebalancePartitions => empty(p)
       // An aggregate with non-empty group expression will return one output row per group when the
       // input to the aggregate is not empty. If the input to the aggregate is empty then all groups
@@ -160,13 +169,40 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
       case _ => p
     }
   }
+
+  protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
+    case _: Repartition => true
+    case r: RepartitionByExpression
+      if r.optNumPartitions.isDefined || r.partitionExpressions.nonEmpty => true
+    case _ => false
+  }
+
+  protected def applyInternal(plan: LogicalPlan): LogicalPlan
+
+  /**
+   * Add a [[ROOT_REPARTITION]] tag for the root user-specified repartition so this rule can
+   * skip optimizing it.
+   */
+  private def addTagForRootRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+    case p: Project => p.mapChildren(addTagForRootRepartition)
+    case f: Filter => f.mapChildren(addTagForRootRepartition)
+    case r if userSpecifiedRepartition(r) =>
+      r.setTagValue(ROOT_REPARTITION, ())
+      r
+    case _ => plan
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val planWithTag = addTagForRootRepartition(plan)
+    applyInternal(planWithTag)
+  }
 }
 
 /**
  * This rule runs in the normal optimizer
  */
 object PropagateEmptyRelation extends PropagateEmptyRelationBase {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     _.containsAnyPattern(LOCAL_RELATION, TRUE_OR_FALSE_LITERAL), ruleId) {
     commonApplyFunc
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 8277e44458bb1..72ef8fdd91b60 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -309,4 +309,42 @@ class PropagateEmptyRelationSuite extends PlanTest {
     val optimized2 = Optimize.execute(plan2)
     comparePlans(optimized2, expected)
   }
+
+  test("Propagate empty relation with repartition") {
+    val emptyRelation = LocalRelation($"a".int, $"b".int)
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition(1).sortBy($"a".asc).analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.distribute($"a")(1).sortBy($"a".asc).analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition().analyze
+    ), emptyRelation.analyze)
+
+    comparePlans(Optimize.execute(
+      emptyRelation.repartition(1).sortBy($"a".asc).repartition().analyze
+    ), emptyRelation.analyze)
+  }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    val emptyRelation = LocalRelation($"a".int, $"b".int)
+    val p1 = emptyRelation.repartition(1).analyze
+    comparePlans(Optimize.execute(p1), p1)
+
+    val p2 = emptyRelation.repartition(1).select($"a").analyze
+    comparePlans(Optimize.execute(p2), p2)
+
+    val p3 = emptyRelation.repartition(1).where($"a" > rand(1)).analyze
+    comparePlans(Optimize.execute(p3), p3)
+
+    val p4 = emptyRelation.repartition(1).where($"a" > rand(1)).select($"a").analyze
+    comparePlans(Optimize.execute(p4), p4)
+
+    val p5 = emptyRelation.sortBy($"a".asc).repartition().limit(1).repartition(1).analyze
+    val expected5 = emptyRelation.repartition(1).analyze
+    comparePlans(Optimize.execute(p5), expected5)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index bab77515f79a2..132c919c29112 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -69,7 +69,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
     empty(j)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+  override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
     // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at
     // `PropagateEmptyRelationBase.commonApplyFunc`
     // LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonApplyFunc`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 43aca31d138f4..b05d320ca07f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3281,6 +3281,13 @@ class DataFrameSuite extends QueryTest
       Row(java.sql.Date.valueOf("2020-02-01"), java.sql.Date.valueOf("2020-02-01")) ::
       Row(java.sql.Date.valueOf("2020-01-01"), java.sql.Date.valueOf("2020-01-02")) :: Nil)
   }
+
+  test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)
+      assert(df.queryExecution.executedPlan.execute().getNumPartitions == 2)
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
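
For reference, here is a minimal standalone sketch of the end-to-end behavior the new DataFrameSuite test pins down. Only the query, the AQE conf key, and the expected partition count come from the patch above; the object name, the `local[2]` master, and the use of `df.rdd.getNumPartitions` (a public-API stand-in for `executedPlan.execute().getNumPartitions`) are illustrative assumptions:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical driver program; only the query, conf, and assertion mirror
// the DataFrameSuite test in the patch.
object Spark39915Check {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-39915")
      // Match the test: exercise the normal optimizer path rather than AQE.
      .config("spark.sql.adaptive.enabled", "false")
      .getOrCreate()

    // `1 < rand()` filters out every row at runtime, but because it is
    // nondeterministic the optimizer cannot simplify it away statically.
    val df = spark.sql("select * from values(1) where 1 < rand()").repartition(2)

    // The user-specified root repartition must survive optimization, so the
    // shuffle still produces exactly 2 partitions even though both are empty.
    assert(df.rdd.getNumPartitions == 2)

    spark.stop()
  }
}
```

At the catalyst level, the `p5` case in PropagateEmptyRelationSuite above shows the complementary guarantee: interior operators over a provably empty input (the sort, the bare `repartition()`, the limit) are still folded away, and only the root user-specified `repartition(1)` is kept, because `addTagForRootRepartition` tags just the topmost user-specified repartition reachable through Project and Filter nodes.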