Skip to content

Commit

Permalink
[SPARK-48700][SQL] Mode expression for complex types (all collations)
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

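Add support in the `mode` aggregate for complex types (structs and arrays) whose subfields are collated strings. Map types with collated keys or values are still rejected and are tracked separately in SPARK-49358.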

### Why are the changes needed?

Full support for collations as per SPARK-48700

### Does this PR introduce _any_ user-facing change?

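Yes. `mode` now accepts struct and array inputs containing collated strings, and map inputs with collated keys or values fail type checking with the new `UNSUPPORTED_MODE_DATA_TYPE` error.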

### How was this patch tested?

Unit tests only, so far.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47154 from GideonPotok/collationmodecomplex.

Lead-authored-by: Gideon P <[email protected]>
Co-authored-by: Gideon Potok <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
2 people authored and MaxGekk committed Oct 1, 2024
1 parent c0a1ea2 commit 97e9bb3
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 102 deletions.
10 changes: 10 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,11 @@
"Cannot process input data types for the expression: <expression>."
],
"subClass" : {
"BAD_INPUTS" : {
"message" : [
"The input data types to <functionName> must be valid, but found the input types <dataType>."
]
},
"MISMATCHED_TYPES" : {
"message" : [
"All input types must be the same except nullable, containsNull, valueContainsNull flags, but found the input types <inputTypes>."
Expand Down Expand Up @@ -1011,6 +1016,11 @@
"The input of <functionName> can't be <dataType> type data."
]
},
"UNSUPPORTED_MODE_DATA_TYPE" : {
"message" : [
"The <mode> does not support the <child> data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields."
]
},
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

Expand All @@ -50,17 +53,20 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

override def checkInputDataTypes(): TypeCheckResult = {
if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
// TODO: SPARK-49358: Mode expression for map type with collated fields
if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
!child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
!UnsafeRowUtils.isBinaryStable(f))) {
/*
* The Mode class uses collation awareness logic to handle string data.
* Complex types with collated fields are not yet supported.
* All complex types except MapType with collated fields are supported.
*/
// TODO: SPARK-48700: Mode expression for complex types (all collations)
super.checkInputDataTypes()
} else {
TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
" a type of binary-unstable type that is " +
s"not currently supported by ${prettyName}.")
TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
messageParameters =
Map("child" -> toSQLType(child.dataType),
"mode" -> toSQLId(prettyName)))
}
}

Expand All @@ -86,6 +92,54 @@ case class Mode(
buffer
}

/**
 * Re-groups the aggregation buffer so that keys which map to the same
 * collation-aware representation are counted together.
 *
 * When `childDataType` is binary stable the buffer is returned unchanged. Otherwise every
 * entry is grouped by `collationAwareTransform(key, childDataType)`; entries in the same
 * group are merged by summing their counts while keeping the original (untransformed) key
 * of the left operand of the reduction.
 *
 * @param childDataType data type of the keys stored in `buffer`
 * @param buffer map from observed value to its occurrence count
 * @return the (key, count) pairs to use for mode selection
 */
private def getCollationAwareBuffer(
    childDataType: DataType,
    buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
  // Decide on the type that was actually passed in, not on `child.dataType`, so the
  // parameter is honoured rather than silently ignored.
  if (UnsafeRowUtils.isBinaryStable(childDataType)) {
    buffer
  } else {
    buffer.groupMapReduce(entry => collationAwareTransform(entry._1, childDataType))(
      entry => entry)((left, right) => (left._1, left._2 + right._2)).values
  }
}

/**
 * Converts a value into the representation used as a grouping key when counting
 * occurrences under non-binary-stable collations.
 *
 *  - Values of a binary-stable type are returned unchanged.
 *  - A null value of a struct, array or string type is returned as null.
 *  - Structs and arrays are rebuilt as a `Seq` of recursively transformed fields/elements.
 *  - Collated strings are replaced by `CollationFactory.getCollationKey`.
 *
 * @param data the value to transform; may be null when it is a nested field or element
 * @param dataType the data type describing `data`
 * @return the transformed value
 * @throws SparkIllegalArgumentException if `dataType` is not binary stable and is not a
 *         struct, array or string type
 */
protected[sql] def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = {
  dataType match {
    case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
    // Nested struct fields and array elements may be null. Without this guard the struct
    // and array branches would dereference null (`toSeq` / `numElements`) and the string
    // branch would hand null to `getCollationKey`. The guard is limited to the supported
    // types so that unsupported types still reach the error branch below.
    case _: StructType | _: ArrayType | _: StringType if data == null => null
    case st: StructType =>
      processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
    case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData])
    case st: StringType =>
      CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId)
    case _ =>
      // NOTE(review): `functionName` is rendered with `toSQLType` here, whereas
      // `checkInputDataTypes` renders the same name with `toSQLId` - confirm which
      // quoting is intended before changing it.
      throw new SparkIllegalArgumentException(
        errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
        messageParameters = Map(
          "expression" -> toSQLExpr(this),
          "functionName" -> toSQLType(prettyName),
          "dataType" -> toSQLType(child.dataType))
      )
  }
}

/**
 * Applies `collationAwareTransform` to every field value of a struct.
 *
 * @param tuples the struct's field values, each paired with its `StructField`
 * @return the transformed field values, in the same order as `tuples`
 */
private def processStructTypeWithBuffer(
    tuples: Seq[(Any, StructField)]): Seq[Any] = {
  tuples.map { case (fieldValue, field) =>
    collationAwareTransform(fieldValue.asInstanceOf[AnyRef], field.dataType)
  }
}

/**
 * Applies `collationAwareTransform` to every element of an array.
 *
 * @param a the array type, which supplies the element type
 * @param data the array whose elements are transformed
 * @return the transformed elements, in index order
 */
private def processArrayTypeWithBuffer(
    a: ArrayType,
    data: ArrayData): Seq[Any] = {
  for (idx <- 0 until data.numElements()) yield {
    collationAwareTransform(data.get(idx, a.elementType), a.elementType)
  }
}

override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
if (buffer.isEmpty) {
return null
Expand All @@ -102,17 +156,12 @@ case class Mode(
* to a single value (the sum of the counts), and finally reduces the groups to a single map.
*
* The new map is then used in the rest of the Mode evaluation logic.
*
* It is expected to work for all simple and complex types with
* collated fields, except for MapType (temporarily).
*/
val collationAwareBuffer = child.dataType match {
case c: StringType if
!CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
val collationId = c.collationId
val modeMap = buffer.toSeq.groupMapReduce {
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case _ => buffer
}
val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)

reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
Expand Down
Loading

0 comments on commit 97e9bb3

Please sign in to comment.