From 3f3d02473265d2ef4f3301a05b7166d93a125184 Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Tue, 13 Aug 2024 23:47:08 +0800
Subject: [PATCH] [SPARK-49205][SQL] KeyGroupedPartitioning should inherit
 HashPartitioningLike

### What changes were proposed in this pull request?

This PR makes `KeyGroupedPartitioning` inherit `HashPartitioningLike`, so that
`BroadcastHashJoin#expandOutputPartitioning` and
`PartitioningPreservingUnaryExecNode` can work with it.

### Why are the changes needed?

To make `KeyGroupedPartitioning` support the alias-aware output partitioning
framework, so that aliasing a key-grouped column no longer hides the
partitioning from downstream operators.
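For illustration, a sketch only (not part of the patch): the first query in the
new test aliases the key-grouped join key, and the alias-aware framework can now
propagate `KeyGroupedPartitioning` through the alias because the partitioning's
child expressions can be rewritten via the `withNewChildrenInternal` override
added here. This assumes the suite's `sql` and `collectAllShuffles` helpers and
its day-partitioned `items`/`purchases` tables in `testcat`:

```scala
// Hedged sketch mirroring the first query of the new test below.
val df = sql(
  """
    |SELECT x, count(*) FROM (
    |  SELECT /*+ broadcast(t2) */ arrive_time as x, * FROM testcat.ns.items t1
    |  JOIN testcat.ns.purchases t2 ON t1.arrive_time = t2.time
    |)
    |GROUP BY x
    |""".stripMargin)
// `x` is an alias of the key-grouped join key `arrive_time`. Previously the
// alias hid the KeyGroupedPartitioning, so an exchange was inserted before
// the aggregation; with this change the plan stays shuffle-free.
assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty)
```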
### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47734 from ulysses-you/SPARK-49205-partitioning.

Authored-by: ulysses-you
Signed-off-by: Kent Yao
---
 .../plans/physical/partitioning.scala    |  9 ++--
 .../KeyGroupedPartitioningSuite.scala    | 50 +++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f8e980747bf2a..30e223c3c3c87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -370,7 +370,7 @@ case class KeyGroupedPartitioning(
     expressions: Seq[Expression],
     numPartitions: Int,
     partitionValues: Seq[InternalRow] = Seq.empty,
-    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
+    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
 
   override def satisfies0(required: Distribution): Boolean = {
     super.satisfies0(required) || {
@@ -421,6 +421,9 @@ case class KeyGroupedPartitioning(
       .distinct
       .map(_.row)
   }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(expressions = newChildren)
 }
 
 object KeyGroupedPartitioning {
@@ -766,8 +769,8 @@ case class CoalescedHashShuffleSpec(
  *
  * @param partitioning key grouped partitioning
  * @param distribution distribution
- * @param joinKeyPosition position of join keys among cluster keys.
- *                        This is set if joining on a subset of cluster keys is allowed.
+ * @param joinKeyPositions position of join keys among cluster keys.
+ *                         This is set if joining on a subset of cluster keys is allowed.
  */
 case class KeyGroupedShuffleSpec(
     partitioning: KeyGroupedPartitioning,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 6a146bc887db4..14598f243785c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.connector
 
+import java.sql.Timestamp
 import java.util.Collections
 
 import org.apache.spark.SparkConf
@@ -583,6 +584,55 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     }
   }
 
+  test("SPARK-49205: KeyGroupedPartitioning should inherit HashPartitioningLike") {
+    val items_partitions = Array(days("arrive_time"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+      "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
+      "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(days("time"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 44.0, cast('2020-01-15' as timestamp)), " +
+      "(1, 45.0, cast('2020-01-15' as timestamp)), " +
+      "(2, 11.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    val df = sql(
+      s"""
+        |SELECT x, count(*) FROM (
+        |  SELECT /*+ broadcast(t2) */ arrive_time as x, * FROM testcat.ns.$items t1
+        |  JOIN testcat.ns.$purchases t2 ON t1.arrive_time = t2.time
+        |)
+        |GROUP BY x
+        |""".stripMargin)
+    checkAnswer(df,
+      Seq(Row(Timestamp.valueOf("2020-01-01 00:00:00"), 6),
+        Row(Timestamp.valueOf("2020-01-15 00:00:00"), 2),
+        Row(Timestamp.valueOf("2020-02-01 00:00:00"), 1)))
+    assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty)
+
+    val df2 = sql(
+      s"""
+        |WITH t1 (SELECT * FROM testcat.ns.$items)
+        |SELECT x, count(*) FROM (
+        |  SELECT /*+ broadcast(t2) */ t2.time as x FROM t1
+        |  JOIN testcat.ns.$purchases t2 ON t1.arrive_time = t2.time
+        |  JOIN t1 t3 ON t1.arrive_time = t3.arrive_time
+        |) GROUP BY x
+        |""".stripMargin)
+    checkAnswer(df2,
+      Seq(Row(Timestamp.valueOf("2020-01-01 00:00:00"), 18),
+        Row(Timestamp.valueOf("2020-01-15 00:00:00"), 2),
+        Row(Timestamp.valueOf("2020-02-01 00:00:00"), 1)))
+    assert(collectAllShuffles(df2.queryExecution.executedPlan).isEmpty)
+  }
+
   test("SPARK-42038: partially clustered: with same partition keys and one side fully clustered") {
     val items_partitions = Array(identity("id"))
     createTable(items, itemsColumns, items_partitions)