diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java new file mode 100644 index 0000000000000..561d66092d641 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; + +/** + * A 'reducer' for output of user-defined functions. + * + * @see ReducibleFunction + * + * A user defined function f_source(x) is 'reducible' on another user_defined function + * f_target(x) if + * + * + * @param reducer input type + * @param reducer output type + * @since 4.0.0 + */ +@Evolving +public interface Reducer { + O reduce(I arg); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java new file mode 100644 index 0000000000000..ef1a14e50cdad --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; + +/** + * Base class for user-defined functions that can be 'reduced' on another function. + * + * A function f_source(x) is 'reducible' on another function f_target(x) if + *
    + *
  • There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) + * for all input x, or
  • + *
  • More generally, there exists reducer functions r1(x) and r2(x) such that + * r1(f_source(x)) = r2(f_target(x)) for all input x.
  • + *
+ *

+ * Examples: + *

    + *
  • Bucket functions where one side has reducer + *
      + *
    • f_source(x) = bucket(4, x)
    • + *
    • f_target(x) = bucket(2, x)
    • + *
    • r(x) = x % 2
    • + *
    + * + *
  • Bucket functions where both sides have reducer + *
      + *
    • f_source(x) = bucket(16, x)
    • + *
    • f_target(x) = bucket(12, x)
    • + *
    • r1(x) = x % 4
    • + *
    • r2(x) = x % 4
    • + *
    + * + *
  • Date functions + *
      + *
    • f_source(x) = days(x)
    • + *
    • f_target(x) = hours(x)
    • + *
    • r(x) = x / 24
    • + *
    + *
+ * @param reducer function input type + * @param reducer function output type + * @since 4.0.0 + */ +@Evolving +public interface ReducibleFunction { + + /** + * This method is for the bucket function. + * + * If this bucket function is 'reducible' on another bucket function, + * return the {@link Reducer} function. + *

+ * For example, to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) + *

    + *
  • thisBucketFunction = bucket
  • + *
  • thisNumBuckets = 4
  • + *
  • otherBucketFunction = bucket
  • + *
  • otherNumBuckets = 2
  • + *
+ * + * @param thisNumBuckets parameter for this function + * @param otherBucketFunction the other parameterized function + * @param otherNumBuckets parameter for the other function + * @return a reduction function if it is reducible, null if not + */ + default Reducer reducer( + int thisNumBuckets, + ReducibleFunction otherBucketFunction, + int otherNumBuckets) { + throw new UnsupportedOperationException(); + } + + /** + * This method is for all other functions. + * + * If this function is 'reducible' on another function, return the {@link Reducer} function. + *

+ * Example of reducing f_source = days(x) on f_target = hours(x) + *

    + *
  • thisFunction = days
  • + *
  • otherFunction = hours
  • + *
+ * + * @param otherFunction the other function + * @return a reduction function if it is reducible, null if not. + */ + default Reducer reducer(ReducibleFunction otherFunction) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 8412de554b711..d37c9d9f6452a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.BoundFunction +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction} import org.apache.spark.sql.types.DataType /** @@ -54,6 +54,61 @@ case class TransformExpression( false } + /** + * Whether this [[TransformExpression]]'s function is compatible with the `other` + * [[TransformExpression]]'s function. + * + * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]] r(x) + * such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x. + * + * @param other the transform expression to compare to + * @return true if compatible, false if not + */ + def isCompatible(other: TransformExpression): Boolean = { + if (isSameFunction(other)) { + true + } else { + (function, other.function) match { + case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => + val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt) + val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt) + thisReducer.isDefined || otherReducer.isDefined + case _ => false + } + } + } + + /** + * Return a [[Reducer]] for this transform expression on another + * on the transform expression. + *

+ * A [[Reducer]] exists for a transform expression function if it is + * 'reducible' on the other expression function. + *

+ * @return reducer function or None if not reducible on the other transform expression + */ + def reducers(other: TransformExpression): Option[Reducer[_, _]] = { + (function, other.function) match { + case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => + reducer(e1, numBucketsOpt, e2, other.numBucketsOpt) + case _ => None + } + } + + // Return a Reducer for a reducible function on another reducible function + private def reducer( + thisFunction: ReducibleFunction[_, _], + thisNumBucketsOpt: Option[Int], + otherFunction: ReducibleFunction[_, _], + otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { + val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { + case (Some(numBuckets), Some(otherNumBuckets)) => + thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) + case _ => thisFunction.reducer(otherFunction) + } + Option(res) + } + override def dataType: DataType = function.resultType() override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index a070c843411ed..5d7e1cfe34e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType} @@ -832,10 +833,42 @@ case class KeyGroupedShuffleSpec( (left, right) match { case (_: LeafExpression, _: LeafExpression) => true case (left: TransformExpression, right: TransformExpression) => - left.isSameFunction(right) + if (SQLConf.get.v2BucketingPushPartValuesEnabled && + !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && + SQLConf.get.v2BucketingAllowCompatibleTransforms) { + left.isCompatible(right) + } else { + left.isSameFunction(right) + } case _ => false } + /** + * Return a set of [[Reducer]] for the partition expressions of this shuffle spec, + * on the partition expressions of another shuffle spec. + *

+ * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is + * 'reducible' on the corresponding partition expression function of the other shuffle spec. + *

+ * If a value is returned, there must be one [[Reducer]] per partition expression. + * A None value in the set indicates that the particular partition expression is not reducible + * on the corresponding expression on the other shuffle spec. + *

+ * Returning none also indicates that none of the partition expressions can be reduced on the + * corresponding expression on the other shuffle spec. + * + * @param other other key-grouped shuffle spec + */ + def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + val results = partitioning.expressions.zip(other.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) + case (_, _) => None + } + + // optimize to not return a value, if none of the partition expressions are reducible + if (results.forall(p => p.isEmpty)) None else Some(results) + } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && // Only support partition expressions are AttributeReference for now partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) @@ -845,6 +878,21 @@ case class KeyGroupedShuffleSpec( } } +object KeyGroupedShuffleSpec { + def reducePartitionValue( + row: InternalRow, + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[_, _]]]): + InternalRowComparableWrapper = { + val partitionVals = row.toSeq(expressions.map(_.dataType)) + val reducedRow = partitionVals.zip(reducers).map{ + case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) + case (v, _) => v + }.toArray + InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) + } +} + case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { override def isCompatibleWith(other: ShuffleSpec): Boolean = { specs.exists(_.isCompatibleWith(other)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ad6c7ad1701b6..23d3c93812b31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1504,6 +1504,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS = + buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled") + .doc("Whether to allow storage-partition join in the case where the partition transforms " + + "are compatible but not identical. This config requires both " + + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + + s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + + "to be disabled." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4798,6 +4810,9 @@ class SQLConf extends Serializable with Logging { def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def v2BucketingAllowCompatibleTransforms: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index afcc762e636a3..4427eda04ab78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.connector.read._ /** @@ -161,6 +162,18 @@ case class BatchScanExec( (groupedParts, expressions) } + // Also re-group the partitions if we are reducing compatible partition expressions + val finalGroupedPartitions = spjParams.reducers match { + case Some(reducers) => + val result = groupedPartitions.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq + val rowOrdering = RowOrdering.createNaturalAscendingOrdering( + partExpressions.map(_.dataType)) + result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + case _ => groupedPartitions + } + // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group // and replicate splits within a partition. @@ -171,7 +184,7 @@ case class BatchScanExec( .get .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap .get(InternalRowComparableWrapper(partValue, partExpressions)) @@ -204,7 +217,7 @@ case class BatchScanExec( } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. - val partitionMapping = groupedPartitions.map { case (partValue, splits) => + val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap @@ -256,6 +269,7 @@ case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + reducers: Option[Seq[Option[Reducer[_, _]]]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 8552c950f6776..681c39aafd83c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} @@ -505,11 +506,28 @@ case class EnsureRequirements( } } - // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, - applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, - applyPartialClustering, replicateRightSide) + // in case of compatible but not identical partition expressions, we apply 'reduce' + // transforms to group one side's partitions as well as the common partition values + val leftReducers = leftSpec.reducers(rightSpec) + val rightReducers = rightSpec.reducers(leftSpec) + + if (leftReducers.isDefined || rightReducers.isDefined) { + mergedPartValues = reduceCommonPartValues(mergedPartValues, + leftSpec.partitioning.expressions, + leftReducers) + mergedPartValues = reduceCommonPartValues(mergedPartValues, + rightSpec.partitioning.expressions, + rightReducers) + val rowOrdering = RowOrdering + .createNaturalAscendingOrdering(partitionExprs.map(_.dataType)) + mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + } + + // Now we need to push-down the common partition information to the scan in each child + newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, + leftReducers, applyPartialClustering, replicateLeftSide) + newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions, + rightReducers, applyPartialClustering, replicateRightSide) } } @@ -527,11 +545,12 @@ case class EnsureRequirements( joinType == LeftAnti || joinType == LeftOuter } - // Populate the common partition values down to the scan nodes - private def populatePartitionValues( + // Populate the common partition information down to the scan nodes + private def populateCommonPartitionInfo( plan: SparkPlan, values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], + reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => @@ -539,13 +558,26 @@ case class EnsureRequirements( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), joinKeyPositions = joinKeyPositions, + reducers = reducers, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => - node.mapChildren(child => populatePartitionValues( - child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) + node.mapChildren(child => populateCommonPartitionInfo( + child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) + } + + private def reduceCommonPartValues( + commonPartValues: Seq[(InternalRow, Int)], + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { + reducers match { + case Some(reducers) => commonPartValues.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) + }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq + case _ => commonPartValues + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 8829464d0a73d..1cb498c6414e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -79,11 +79,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Collections.emptyMap[String, String] } private val table: String = "tbl" + private val schema = new StructType() .add("id", IntegerType) .add("data", StringType) .add("ts", TimestampType) + private val schema2 = new StructType() + .add("store_id", IntegerType) + .add("dept_id", IntegerType) + .add("data", StringType) + test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) @@ -1406,6 +1412,470 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((2, 4), (4, 2)), + ((4, 2), (2, 4)), + ((2, 2), (4, 6)), + ((6, 2), (2, 2))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, schema2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |select t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. + partitions.length) + val expectedBuckets = Math.min(table1buckets1, table2buckets1) * + Math.min(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + + test("SPARK-47094: Support compatible buckets with common divisor") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((6, 4), (4, 6)), + ((6, 6), (4, 4)), + ((4, 4), (6, 6)), + ((4, 6), (6, 4))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, schema2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |select t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. + partitions.length) + + def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt + val expectedBuckets = gcd(table1buckets1, table2buckets1) * + gcd(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach { + case (table1buckets, table2buckets) => + catalog.clearTables() + + val partition1 = Array(identity("data"), + bucket(table1buckets, "dept_id")) + val partition2 = Array(bucket(3, "store_id"), + bucket(table2buckets, "dept_id")) + + createTable(table1, schema2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 0, 'ab'), " + + "(2, 1, 'ac'), " + + "(3, 2, 'ad'), " + + "(4, 3, 'ae'), " + + "(5, 4, 'af'), " + + "(6, 5, 'ag'), " + + + // value without other side match + "(6, 6, 'xx')" + ) + + createTable(table2, schema2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(6, 0, '01'), " + + "(5, 1, '02'), " + // duplicate partition key + "(5, 1, '03'), " + + "(4, 2, '04'), " + + "(3, 3, '05'), " + + "(2, 4, '06'), " + + "(1, 5, '07'), " + + + // value without other side match + "(7, 7, '99')" + ) + + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |select t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. + partitions.length) + + val expectedBuckets = Math.min(table1buckets, table2buckets) + + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 6, 0, 0, "aa", "01"), + Row(1, 6, 0, 0, "ab", "01"), + Row(2, 5, 1, 1, "ac", "02"), + Row(2, 5, 1, 1, "ac", "03"), + Row(3, 4, 2, 2, "ad", "04"), + Row(4, 3, 3, 3, "ae", "05"), + Row(5, 2, 4, 4, "af", "06"), + Row(6, 1, 5, 5, "ag", "07") + )) + } + } + } + + test("SPARK-47094: Compatible buckets does not support SPJ with " + + "push-down values or partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + + val partition1 = Array(bucket(4, "store_id"), + bucket(2, "dept_id")) + val partition2 = Array(bucket(2, "store_id"), + bucket(2, "dept_id")) + + createTable(table1, schema2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + createTable(table2, schema2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + Seq(true, false).foreach{ allowPushDown => + Seq(true, false).foreach{ partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> allowPushDown.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |select t1.store_id, t1.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. + partitions.length) + + (allowPushDown, partiallyClustered) match { + case (true, false) => + assert(shuffles.isEmpty, "SPJ should be triggered") + assert(scans == Seq(2, 2)) + case (_, _) => + assert(shuffles.nonEmpty, "SPJ should not be triggered") + assert(scans == Seq(3, 2)) + } + + checkAnswer(df, Seq( + Row(0, 0, 0, 0, "aa", "aa"), + Row(1, 1, 1, 1, "bb", "bb"), + Row(2, 2, 2, 2, "cc", "cc") + )) + } + } + } + } + test("SPARK-44647: test join key is the second cluster key") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 61895d49c4a2a..5cdb900901056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -76,7 +76,7 @@ object UnboundBucketFunction extends UnboundFunction { override def name(): String = "bucket" } -object BucketFunction extends ScalarFunction[Int] { +object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) override def resultType(): DataType = IntegerType override def name(): String = "bucket" @@ -85,6 +85,26 @@ object BucketFunction extends ScalarFunction[Int] { override def produceResult(input: InternalRow): Int = { (input.getLong(1) % input.getInt(0)).toInt } + + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = { + + if (otherFunc == BucketFunction) { + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) + if (gcd != thisNumBuckets) { + return BucketReducer(thisNumBuckets, gcd) + } + } + null + } + + private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt +} + +case class BucketReducer(thisNumBuckets: Int, divisor: Int) extends Reducer[Int, Int] { + override def reduce(bucket: Int): Int = bucket % divisor } object UnboundStringSelfFunction extends UnboundFunction {