[SPARK-49205][SQL] KeyGroupedPartitioning should inherit HashPartitioningLike

### What changes were proposed in this pull request?

This PR makes `KeyGroupedPartitioning` inherit `HashPartitioningLike`, so that `BroadcastHashJoin#expandOutputPartitioning` and `PartitioningPreservingUnaryExecNode` can work with it.
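For orientation, a minimal sketch of the type relationship this change establishes, using stub types rather than Catalyst's real definitions: `HashPartitioningLike` mixes `Expression` into `Partitioning`, so a partitioning that extends it can be rewritten by generic expression transforms.

```scala
// Hedged sketch; stub types only, not Catalyst's actual traits.
trait Expression {
  def children: Seq[Expression]
  def withNewChildren(newChildren: Seq[Expression]): Expression
}

trait Partitioning { def numPartitions: Int }

// A HashPartitioningLike *is* an Expression whose children are the
// partitioning expressions, so alias-aware machinery such as
// BroadcastHashJoin#expandOutputPartitioning and
// PartitioningPreservingUnaryExecNode can rewrite those children the same
// way it rewrites any other expression tree.
trait HashPartitioningLike extends Expression with Partitioning {
  def expressions: Seq[Expression]
  override def children: Seq[Expression] = expressions
}
```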

### Why are the changes needed?

To make `KeyGroupedPartitioning` work with the alias-aware framework.
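As a hedged illustration (not part of this patch): with a DSv2 table whose scan reports `KeyGroupedPartitioning` over `days(arrive_time)`, grouping by an alias of the partition column can now reuse the scan's partitioning instead of shuffling, because the alias-aware framework projects the partitioning through the alias.

```scala
// Hypothetical query; assumes testcat.ns.items is key-grouped by
// days(arrive_time), as in the test added below.
val df = spark.sql(
  """
    |SELECT x, count(*)
    |FROM (SELECT arrive_time AS x FROM testcat.ns.items)
    |GROUP BY x
    |""".stripMargin)
// Previously KeyGroupedPartitioning was invisible to the alias-aware
// framework, so the aggregate forced an exchange; with this change the
// partitioning is rewritten over x and the shuffle can be elided.
```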

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47734 from ulysses-you/SPARK-49205-partitioning.

Authored-by: ulysses-you <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
ulysses-you authored and yaooqinn committed Aug 13, 2024
1 parent d82c695 commit 3f3d024
Showing 2 changed files with 56 additions and 3 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

@@ -370,7 +370,7 @@ case class KeyGroupedPartitioning(
    expressions: Seq[Expression],
    numPartitions: Int,
    partitionValues: Seq[InternalRow] = Seq.empty,
-    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
+    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {

  override def satisfies0(required: Distribution): Boolean = {
    super.satisfies0(required) || {
@@ -421,6 +421,9 @@
      .distinct
      .map(_.row)
  }
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
+    copy(expressions = newChildren)
}

object KeyGroupedPartitioning {
@@ -766,8 +769,8 @@ case class CoalescedHashShuffleSpec(
 *
 * @param partitioning key grouped partitioning
 * @param distribution distribution
- * @param joinKeyPosition position of join keys among cluster keys.
- *                        This is set if joining on a subset of cluster keys is allowed.
+ * @param joinKeyPositions position of join keys among cluster keys.
+ *                         This is set if joining on a subset of cluster keys is allowed.
 */
case class KeyGroupedShuffleSpec(
    partitioning: KeyGroupedPartitioning,
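A note on the `withNewChildrenInternal` override added above: once `KeyGroupedPartitioning` is an `Expression`, Catalyst tree transforms must be able to rebuild it with substituted children. A self-contained sketch of that mechanic, with stand-in types rather than Catalyst's:

```scala
// Stand-in expression tree; AttrRef/Days/KeyGrouped are illustrative only.
sealed trait Expr {
  def children: Seq[Expr]
  def withNewChildren(cs: Seq[Expr]): Expr
}
final case class AttrRef(name: String) extends Expr {
  def children: Seq[Expr] = Nil
  def withNewChildren(cs: Seq[Expr]): Expr = this
}
final case class Days(child: Expr) extends Expr {
  def children: Seq[Expr] = Seq(child)
  def withNewChildren(cs: Seq[Expr]): Expr = Days(cs.head)
}
// Analogue of the withNewChildrenInternal override added in this commit:
final case class KeyGrouped(expressions: Seq[Expr]) extends Expr {
  def children: Seq[Expr] = expressions
  def withNewChildren(cs: Seq[Expr]): Expr = KeyGrouped(cs)
}

object RewriteDemo extends App {
  // Replace arrive_time with its alias x everywhere, as an alias-aware
  // rule would, relying on withNewChildren to rebuild each node.
  def rewrite(e: Expr, from: Expr, to: Expr): Expr =
    if (e == from) to else e.withNewChildren(e.children.map(rewrite(_, from, to)))

  val p = KeyGrouped(Seq(Days(AttrRef("arrive_time"))))
  println(rewrite(p, AttrRef("arrive_time"), AttrRef("x")))
  // KeyGrouped(List(Days(AttrRef(x))))
}
```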
sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connector

+import java.sql.Timestamp
import java.util.Collections

import org.apache.spark.SparkConf
@@ -583,6 +584,55 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("SPARK-49205: KeyGroupedPartitioning should inherit HashPartitioningLike") {
val items_partitions = Array(days("arrive_time"))
createTable(items, itemsColumns, items_partitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " +
"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

val purchases_partitions = Array(days("time"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(1, 44.0, cast('2020-01-15' as timestamp)), " +
"(1, 45.0, cast('2020-01-15' as timestamp)), " +
"(2, 11.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

val df = sql(
s"""
|SELECT x, count(*) FROM (
| SELECT /*+ broadcast(t2) */ arrive_time as x, * FROM testcat.ns.$items t1
| JOIN testcat.ns.$purchases t2 ON t1.arrive_time = t2.time
|)
|GROUP BY x
|""".stripMargin)
checkAnswer(df,
Seq(Row(Timestamp.valueOf("2020-01-01 00:00:00"), 6),
Row(Timestamp.valueOf("2020-01-15 00:00:00"), 2),
Row(Timestamp.valueOf("2020-02-01 00:00:00"), 1)))
assert(collectAllShuffles(df.queryExecution.executedPlan).isEmpty)

val df2 = sql(
s"""
|WITH t1 (SELECT * FROM testcat.ns.$items)
|SELECT x, count(*) FROM (
| SELECT /*+ broadcast(t2) */ t2.time as x FROM t1
| JOIN testcat.ns.$purchases t2 ON t1.arrive_time = t2.time
| JOIN t1 t3 ON t1.arrive_time = t3.arrive_time
|) GROUP BY x
|""".stripMargin)
checkAnswer(df2,
Seq(Row(Timestamp.valueOf("2020-01-01 00:00:00"), 18),
Row(Timestamp.valueOf("2020-01-15 00:00:00"), 2),
Row(Timestamp.valueOf("2020-02-01 00:00:00"), 1)))
assert(collectAllShuffles(df2.queryExecution.executedPlan).isEmpty)
}

test("SPARK-42038: partially clustered: with same partition keys and one side fully clustered") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)
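The `collectAllShuffles` assertions above verify that no exchange is inserted anywhere in the plan. Outside the suite's helpers, a similar check can be written against the public plan API (a sketch; `shuffles` is a hypothetical helper, not part of this patch):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

// Collect shuffle exchanges from the executed physical plan; an empty result
// means the key-grouped (storage-partitioned) layout was preserved end to end.
def shuffles(df: DataFrame) =
  df.queryExecution.executedPlan.collect { case s: ShuffleExchangeLike => s }
```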
