diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
index 3d3f0edeadd0a..240abdd468c02 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
@@ -37,5 +37,6 @@ public interface ReducibleFunction<T, A> extends ScalarFunction<T> {
    * @param otherArgument argument for other function instance
    * @return a reduction function if it is reducible, none if not
    */
-  Option<Reducer<A>> reducer(ReducibleFunction<?, ?> other, Option<?> thisArgument, Option<?> otherArgument);
+  Option<Reducer<A>> reducer(ReducibleFunction<?, ?> other, Option<?> thisArgument,
+      Option<?> otherArgument);
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
index cc5810993a9fb..9b53ce54b456e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -69,8 +69,7 @@ case class TransformExpression(
       true
     } else {
       (function, other.function) match {
-        case (f: ReducibleFunction[Any, Any] @unchecked,
-          o: ReducibleFunction[Any, Any] @unchecked) =>
+        case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
           val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt)
           val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt)
           reducer.isDefined || otherReducer.isDefined
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 3c1cc5e2e9e92..cb505779ece8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -651,7 +651,7 @@ trait ShuffleSpec {
    * Returning none also indicates that none of the partition expressions can be reduced on the
    * corresponding expression on the other shuffle spec.
    */
-  def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None
+  def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[_]]]] = None
 }

 case object SinglePartitionShuffleSpec extends ShuffleSpec {
@@ -854,17 +854,16 @@ case class KeyGroupedShuffleSpec(
     KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
   }

-  override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = {
+  override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_]]]] = {
     other match {
       case otherSpec: KeyGroupedShuffleSpec =>
         val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map {
           case (e1: TransformExpression, e2: TransformExpression)
-            if e1.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked]
-              && e2.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] =>
-            e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer(
-              e2.function.asInstanceOf[ReducibleFunction[Any, Any]],
-              e1.numBucketsOpt.map(a => a.asInstanceOf[Any]),
-              e2.numBucketsOpt.map(a => a.asInstanceOf[Any]))
+            if e1.function.isInstanceOf[ReducibleFunction[_, _]]
+              && e2.function.isInstanceOf[ReducibleFunction[_, _]] =>
+            e1.function.asInstanceOf[ReducibleFunction[_, _]].reducer(
+              e2.function.asInstanceOf[ReducibleFunction[_, _]],
+              e1.numBucketsOpt, e2.numBucketsOpt)
           case (_, _) => None
         }

@@ -892,11 +891,11 @@ case class KeyGroupedShuffleSpec(
 object KeyGroupedShuffleSpec {
   def reducePartitionValue(row: InternalRow,
     expressions: Seq[Expression],
-    reducers: Seq[Option[Reducer[Any]]]):
+    reducers: Seq[Option[Reducer[_]]]):
   InternalRowComparableWrapper = {
     val partitionVals = row.toSeq(expressions.map(_.dataType))
     val reducedRow = partitionVals.zip(reducers).map{
-      case (v, Some(reducer)) => reducer.reduce(v)
+      case (v, Some(reducer: Reducer[Any])) => reducer.reduce(v)
       case (v, _) => v
     }.toArray
     InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 955d13c8f9be5..43c2d299ad30b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -272,7 +272,7 @@ case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
     joinKeyPositions: Option[Seq[Int]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
-    reducers: Option[Seq[Option[Reducer[Any]]]] = None,
+    reducers: Option[Seq[Option[Reducer[_]]]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
   override def equals(other: Any): Boolean = other match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 1c4c797861b2d..2cbc8143f5650 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -550,7 +550,7 @@ case class EnsureRequirements(
       plan: SparkPlan,
       values: Seq[(InternalRow, Int)],
       joinKeyPositions: Option[Seq[Int]],
-      reducers: Option[Seq[Option[Reducer[Any]]]],
+      reducers: Option[Seq[Option[Reducer[_]]]],
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
@@ -570,7 +570,7 @@ case class EnsureRequirements(
   private def reduceCommonPartValues(
       commonPartValues: Seq[(InternalRow, Int)],
       expressions: Seq[Expression],
-      reducers: Option[Seq[Option[Reducer[Any]]]]) = {
+      reducers: Option[Seq[Option[Reducer[_]]]]) = {
     reducers match {
       case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
         KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
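For context (not part of the patch): a minimal sketch of how a data source function might implement the `ReducibleFunction` API this diff touches. It assumes the `ReducibleFunction[T, A]` / `Reducer[A]` shapes shown in the first hunk, in particular that `Reducer[A]` exposes a single `reduce(A): A` method; `BucketFunction`, `ModReducer`, and the simplified bucket computation are hypothetical, not Spark code. The idea mirrors bucketing: `bucket(n, x)` reduces onto `bucket(m, x)` whenever `m` divides `n`, via `r(b) = b % m`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical reducer: maps a bucket id computed with the larger bucket
// count onto the id the smaller count would produce (valid when n % m == 0).
case class ModReducer(divisor: Int) extends Reducer[Int] {
  override def reduce(bucket: Int): Int = bucket % divisor
}

// Hypothetical two-argument bucket function: bucket(numBuckets, value).
object BucketFunction extends ReducibleFunction[Int, Int] {
  override def name(): String = "bucket"
  override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
  override def resultType(): DataType = IntegerType

  // Simplified stand-in for pmod(hash(value), numBuckets).
  override def produceResult(input: InternalRow): Int =
    Math.floorMod(input.getInt(1), input.getInt(0))

  // bucket(n, x) is reducible onto bucket(m, x) iff m divides n; the
  // arguments carry each side's bucket count (numBucketsOpt above).
  override def reducer(
      other: ReducibleFunction[_, _],
      thisArgument: Option[_],
      otherArgument: Option[_]): Option[Reducer[Int]] = {
    (thisArgument, otherArgument) match {
      case (Some(n: Int), Some(m: Int))
          if other == BucketFunction && n != m && n % m == 0 =>
        Some(ModReducer(m))
      case _ => None
    }
  }
}
```

With such a reducer, `KeyGroupedShuffleSpec.reducers` can align two key-grouped sides whose bucket counts differ: the finer-grained side's partition values are mapped through `r` in `reducePartitionValue` rather than inserting a shuffle.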