[SPARK-47094][SQL] SPJ: Dynamically rebalance number of buckets when they are not equal (apache#1946)

### What changes were proposed in this PR?
-- Allow SPJ between 'compatible' bucket functions.
-- Add a mechanism to define 'reducible' functions: a function whose output can be 'reduced' to that of another function for all inputs.

### Why are the changes needed?
-- SPJ currently applies only if the partition transform expressions on both sides are identical.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new tests in KeyGroupedPartitioningSuite.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#45267 from szehon-ho/spj-uneven-buckets.

Authored-by: Szehon Ho <[email protected]>

Signed-off-by: Chao Sun <[email protected]>
Co-authored-by: Szehon Ho <[email protected]>
2 people authored and GitHub Enterprise committed Apr 17, 2024
1 parent 80f9b60 commit ae0c1f6
Showing 9 changed files with 817 additions and 15 deletions.
org/apache/spark/sql/connector/catalog/functions/Reducer.java (new file)
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
 * A 'reducer' for the output of a user-defined function.
 *
 * A user-defined function f_source(x) is 'reducible' on another user-defined function
 * f_target(x) if
 * <ul>
 *   <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for
 *   all input x, or </li>
 *   <li> More generally, there exist reducer functions r1(x) and r2(x) such that
 *   r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
 * </ul>
 *
 * @see ReducibleFunction
 * @param <I> reducer input type
 * @param <O> reducer output type
 * @since 4.0.0
*/
@Evolving
public interface Reducer<I, O> {
O reduce(I arg);
}
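For illustration, a minimal Scala sketch of a `Reducer` (hypothetical values, assuming a hash-based bucket transform where bucket(n, x) = hash(x) % n):

```scala
import org.apache.spark.sql.connector.catalog.functions.Reducer

// Hypothetical reducer: r(x) = x % 2 maps bucket ids produced by bucket(4, x)
// to the bucket ids bucket(2, x) would produce for the same input.
val modTwo: Reducer[Int, Int] = new Reducer[Int, Int] {
  override def reduce(arg: Int): Int = arg % 2
}

assert(modTwo.reduce(3) == 1) // rows in source bucket 3 land in target bucket 1
```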
org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java (new file)
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
 * Base interface for user-defined functions that can be 'reduced' on another function.
 *
 * A function f_source(x) is 'reducible' on another function f_target(x) if
 * <ul>
 *   <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x)
 *   for all input x, or </li>
 *   <li> More generally, there exist reducer functions r1(x) and r2(x) such that
 *   r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
 * </ul>
 * <p>
 * Examples:
 * <ul>
 *   <li>Bucket functions where one side has a reducer
 *   <ul>
 *     <li>f_source(x) = bucket(4, x)</li>
 *     <li>f_target(x) = bucket(2, x)</li>
 *     <li>r(x) = x % 2</li>
 *   </ul>
 *   </li>
 *
 *   <li>Bucket functions where both sides have reducers
 *   <ul>
 *     <li>f_source(x) = bucket(16, x)</li>
 *     <li>f_target(x) = bucket(12, x)</li>
 *     <li>r1(x) = x % 4</li>
 *     <li>r2(x) = x % 4</li>
 *   </ul>
 *   </li>
 *
 *   <li>Date/time functions, where hours reduce to days
 *   <ul>
 *     <li>f_source(x) = hours(x)</li>
 *     <li>f_target(x) = days(x)</li>
 *     <li>r(x) = x / 24</li>
 *   </ul>
 *   </li>
 * </ul>
* @param <I> reducer function input type
* @param <O> reducer function output type
* @since 4.0.0
*/
@Evolving
public interface ReducibleFunction<I, O> {

  /**
   * Returns the reducer for reducing this function on another function. This variant
   * is for parameterized bucket functions.
   *
   * If this bucket function is 'reducible' on the other bucket function,
   * return the {@link Reducer} function.
   * <p>
   * For example, to return the reducer for reducing f_source = bucket(4, x)
   * on f_target = bucket(2, x):
   * <ul>
   *   <li>thisBucketFunction = bucket</li>
   *   <li>thisNumBuckets = 4</li>
   *   <li>otherBucketFunction = bucket</li>
   *   <li>otherNumBuckets = 2</li>
   * </ul>
   *
   * @param thisNumBuckets the number of buckets for this bucket function
   * @param otherBucketFunction the other parameterized bucket function
   * @param otherNumBuckets the number of buckets for the other bucket function
   * @return a reducer function, or null if this function is not reducible on the other
   */
  default Reducer<I, O> reducer(
      int thisNumBuckets,
      ReducibleFunction<?, ?> otherBucketFunction,
      int otherNumBuckets) {
    throw new UnsupportedOperationException();
  }

  /**
   * Returns the reducer for reducing this function on another function. This variant
   * is for all other (non-bucket) functions.
   *
   * If this function is 'reducible' on the other function, return the {@link Reducer} function.
   * <p>
   * For example, to return the reducer for reducing f_source = hours(x) on f_target = days(x):
   * <ul>
   *   <li>thisFunction = hours</li>
   *   <li>otherFunction = days</li>
   * </ul>
   *
   * @param otherFunction the other function
   * @return a reducer function, or null if this function is not reducible on the other
   */
  default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
    throw new UnsupportedOperationException();
  }
}
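To make the contract concrete, here is a hedged Scala sketch of a connector-side bucket transform implementing this interface. The class name and the divisibility rule are illustrative assumptions, not part of this PR; the API only requires that `reducer` return null when no reduction exists:

```scala
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}

// Hypothetical: a hash-based bucket transform (bucket(n, x) = hash(x) % n) is
// reducible on another instance of itself whenever the other bucket count
// divides this one, via r(x) = x % otherNumBuckets.
class BucketFunction extends ReducibleFunction[Int, Int] {
  override def reducer(
      thisNumBuckets: Int,
      otherBucketFunction: ReducibleFunction[_, _],
      otherNumBuckets: Int): Reducer[Int, Int] = {
    if (otherBucketFunction.isInstanceOf[BucketFunction] &&
        thisNumBuckets % otherNumBuckets == 0) {
      new Reducer[Int, Int] {
        override def reduce(bucket: Int): Int = bucket % otherNumBuckets
      }
    } else {
      null // not reducible on the other function
    }
  }
}
```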
org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.connector.catalog.functions.BoundFunction
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
import org.apache.spark.sql.types.DataType

/**
@@ -54,6 +54,61 @@ case class TransformExpression(
false
}

/**
* Whether this [[TransformExpression]]'s function is compatible with the `other`
* [[TransformExpression]]'s function.
*
* This is true if both are instances of [[ReducibleFunction]], and there exists a [[Reducer]]
* r(x) such that r(f1(x)) = f2(x), or r(f2(x)) = f1(x), for all input x, where f1 and f2 are
* the functions of this and the other [[TransformExpression]], respectively.
*
* @param other the transform expression to compare to
* @return true if compatible, false if not
*/
def isCompatible(other: TransformExpression): Boolean = {
if (isSameFunction(other)) {
true
} else {
(function, other.function) match {
case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
thisReducer.isDefined || otherReducer.isDefined
case _ => false
}
}
}

/**
* Returns a [[Reducer]] for this transform expression's function on another
* transform expression's function.
* <p>
* A [[Reducer]] exists for a transform expression function if it is
* 'reducible' on the other expression's function.
*
* @return the reducer function, or None if this function is not reducible on the other
*/
def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
(function, other.function) match {
case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
case _ => None
}
}

// Return a Reducer for a reducible function on another reducible function
private def reducer(
thisFunction: ReducibleFunction[_, _],
thisNumBucketsOpt: Option[Int],
otherFunction: ReducibleFunction[_, _],
otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
case (Some(numBuckets), Some(otherNumBuckets)) =>
thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
case _ => thisFunction.reducer(otherFunction)
}
Option(res)
}

override def dataType: DataType = function.resultType()

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
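As a worked example of the two-sided case that `isCompatible` admits (hypothetical reducers, assuming hash-based buckets where bucket(n, x) = hash(x) % n): bucket(16, x) and bucket(12, x) are compatible because both sides can be reduced to bucket(gcd(16, 12), x) = bucket(4, x):

```scala
// Hypothetical: pick the gcd of the two bucket counts as the common target.
val (n1, n2) = (16, 12)
val g = BigInt(n1).gcd(BigInt(n2)).toInt // = 4
val r1 = (b: Int) => b % g // reducer applied to the bucket(16) side
val r2 = (b: Int) => b % g // reducer applied to the bucket(12) side
// For any hash value h: ((h % 16) % 4) == ((h % 12) % 4) == (h % 4),
// so the reduced partition values from the two sides line up.
val h = 42
assert(r1(h % n1) == r2(h % n2))
```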
org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -832,10 +833,42 @@ case class KeyGroupedShuffleSpec(
(left, right) match {
case (_: LeafExpression, _: LeafExpression) => true
case (left: TransformExpression, right: TransformExpression) =>
-        left.isSameFunction(right)
+        if (SQLConf.get.v2BucketingPushPartValuesEnabled &&
+            !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+            SQLConf.get.v2BucketingAllowCompatibleTransforms) {
+          left.isCompatible(right)
+        } else {
+          left.isSameFunction(right)
+        }
case _ => false
}

/**
* Returns a sequence of [[Reducer]] for the partition expressions of this shuffle spec,
* on the partition expressions of another shuffle spec.
* <p>
* A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
* 'reducible' on the corresponding partition expression function of the other shuffle spec.
* <p>
* If a value is returned, there must be one [[Reducer]] per partition expression.
* A None entry in the sequence indicates that the particular partition expression is not
* reducible on the corresponding expression of the other shuffle spec.
* <p>
* Returning None for the whole result indicates that none of the partition expressions
* can be reduced on the corresponding expressions of the other shuffle spec.
*
* @param other the other key-grouped shuffle spec
*/
def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
val results = partitioning.expressions.zip(other.partitioning.expressions).map {
case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
case (_, _) => None
}

// Optimization: return None if none of the partition expressions is reducible.
if (results.forall(_.isEmpty)) None else Some(results)
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions that are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
@@ -845,6 +878,21 @@
}
}

object KeyGroupedShuffleSpec {
  def reducePartitionValue(
      row: InternalRow,
      expressions: Seq[Expression],
      reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = {
    val partitionVals = row.toSeq(expressions.map(_.dataType))
    val reducedRow = partitionVals.zip(reducers).map {
      case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
      case (v, _) => v
    }.toArray
    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
  }
}

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
override def isCompatibleWith(other: ShuffleSpec): Boolean = {
specs.exists(_.isCompatibleWith(other))
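A small, self-contained sketch of what `reducePartitionValue` achieves for the partition values themselves (hypothetical values, same hash-based bucket assumption as above):

```scala
// Partition values on the bucket(4) side, reduced with r(x) = x % 2 so they can
// be grouped and compared against partition values on the bucket(2) side.
val sourcePartitionValues = Seq(0, 1, 2, 3)
val reduced = sourcePartitionValues.map(_ % 2).distinct // => Seq(0, 1)
// After reduction both sides share the partition value set {0, 1}, so each group
// of input splits on one side matches exactly one group on the other side.
```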
org/apache/spark/sql/internal/SQLConf.scala
@@ -1504,6 +1504,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS =
buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled")
.doc("Whether to allow storage-partition join in the case where the partition transforms " +
"are compatible but not identical. This config requires both " +
s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"to be disabled."
)
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4798,6 +4810,9 @@ class SQLConf extends Serializable with Logging {
def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def v2BucketingAllowCompatibleTransforms: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

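Putting the doc string's requirements together, a minimal session setup might look like the following (config keys taken from the definitions above; assumes `spark` is an active SparkSession):

```scala
// Enable SPJ across compatible-but-not-identical partition transforms.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "false")
spark.conf.set("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled", "true")
```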
org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.connector.read._

/**
@@ -161,6 +162,18 @@ case class BatchScanExec(
(groupedParts, expressions)
}

// Also re-group the partitions if we are reducing compatible partition expressions
val finalGroupedPartitions = spjParams.reducers match {
case Some(reducers) =>
val result = groupedPartitions.groupBy { case (row, _) =>
KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
}.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
partExpressions.map(_.dataType))
result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
case _ => groupedPartitions
}

// When partially clustered, the input partitions are not grouped by partition
// values. Here we'll need to check `commonPartitionValues` and decide how to group
// and replicate splits within a partition.
@@ -171,7 +184,7 @@
.get
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
.toMap
-        val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
+        val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, partExpressions))
@@ -204,7 +217,7 @@
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
-      val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
+      val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, partExpressions) -> splits
}.toMap

@@ -256,6 +269,7 @@ case class StoragePartitionJoinParams(
keyGroupedPartitioning: Option[Seq[Expression]] = None,
joinKeyPositions: Option[Seq[Int]] = None,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false) {
override def equals(other: Any): Boolean = other match {
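The re-grouping step above can be pictured with plain collections, mirroring the groupBy in `finalGroupedPartitions` (hypothetical splits, reducer r(x) = x % 2):

```scala
// Partitions keyed by bucket(4) partition values collapse into bucket(2) groups.
val grouped = Seq(0 -> Seq("split0"), 1 -> Seq("split1"), 2 -> Seq("split2"), 3 -> Seq("split3"))
val regrouped = grouped
  .groupBy { case (bucket, _) => bucket % 2 } // apply the reducer to the partition value
  .map { case (b, entries) => b -> entries.flatMap(_._2) } // merge input splits per reduced value
  .toSeq.sortBy(_._1) // restore the natural ordering expected downstream
// => Seq(0 -> Seq("split0", "split2"), 1 -> Seq("split1", "split3"))
```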
(The remaining 3 changed files, including KeyGroupedPartitioningSuite, are not shown.)
