Skip to content

Commit

Permalink
[SPARK-48572][SQL] Fix DateSub, DateAdd, WindowTime, TimeWindow and S…
Browse files Browse the repository at this point in the history
…essionWindow expressions

### What changes were proposed in this pull request?
Fix for listed expressions.

### Why are the changes needed?
These expressions are found to be faulty when working with collations.

### Does this PR introduce _any_ user-facing change?
Yes, it fixes expressions that are not working with collations.

### How was this patch tested?
Tests added for expressions to reproduce errors.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #46943 from mihailom-db/SPARK-48572.

Authored-by: Mihailo Milosevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mihailom-db authored and cloud-fan committed Jun 17, 2024
1 parent e00d26f commit d3455df
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,12 @@ object CollationTypeCasts extends TypeCoercionRule {

case otherExpr @ (
_: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least |
_: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace |
_: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight) =>
_: Coalesce | _: ArrayContains | _: ArrayExcept | _: ConcatWs | _: Mask | _: StringReplace |
_: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight |
_: ArrayIntersect | _: ArrayPosition | _: ArrayRemove | _: ArrayUnion | _: ArraysOverlap |
_: Contains | _: EndsWith | _: EqualNullSafe | _: EqualTo | _: FindInSet | _: GreaterThan |
_: GreaterThanOrEqual | _: LessThan | _: LessThanOrEqual | _: StartsWith | _: StringInstr |
_: ToNumber | _: TryToNumber) =>
val newChildren = collateToSingleType(otherExpr.children)
otherExpr.withNewChildren(newChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ object TimeWindow {
* that we can use `window` in SQL.
*/
def parseExpression(expr: Expression): Long = expr match {
case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
case NonNullLiteral(s, _: StringType) => getIntervalInMicroSeconds(s.toString)
case IntegerLiteral(i) => i.toLong
case NonNullLiteral(l, LongType) => l.toString.toLong
case _ => throw QueryCompilationErrors.invalidLiteralForWindowDurationError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)

private val defaultElementType: DataType = {
if (useStringTypeWhenEmpty) {
StringType
SQLConf.get.defaultStringType
} else {
NullType
}
Expand Down Expand Up @@ -354,7 +354,7 @@ case class MapFromArrays(left: Expression, right: Expression)
case object NamePlaceholder extends LeafExpression with Unevaluable {
override lazy val resolved: Boolean = false
override def nullable: Boolean = false
override def dataType: DataType = StringType
override def dataType: DataType = SQLConf.get.defaultStringType
override def prettyName: String = "NamePlaceholder"
override def toString: String = prettyName
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,18 @@ private[sql] abstract class DataTypeExpression(val dataType: DataType) {
}

private[sql] case object BooleanTypeExpression extends DataTypeExpression(BooleanType)
private[sql] case object StringTypeExpression extends DataTypeExpression(StringType)
private[sql] case object StringTypeExpression {
/**
* Enables matching against StringType for expressions:
* {{{
* case Cast(child @ StringType(collationId), NumericType) =>
* ...
* }}}
*/
def unapply(e: Expression): Boolean = {
e.dataType.isInstanceOf[StringType]
}
}
private[sql] case object TimestampTypeExpression extends DataTypeExpression(TimestampType)
private[sql] case object DateTypeExpression extends DataTypeExpression(DateType)
private[sql] case object ByteTypeExpression extends DataTypeExpression(ByteType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,84 @@ class CollationSQLExpressionsSuite
})
}

test("DateAdd expression with collation") {
// Supported collations
testSuppCollations.foreach(collationName => {
val query = s"""select date_add(collate('2016-07-30', '${collationName}'), 1)"""
// Result & data type check
val testQuery = sql(query)
val dataType = DateType
val expectedResult = "2016-07-31"
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
checkAnswer(testQuery, Row(Date.valueOf(expectedResult)))
})
}

test("DateSub expression with collation") {
// Supported collations
testSuppCollations.foreach(collationName => {
val query = s"""select date_sub(collate('2016-07-30', '${collationName}'), 1)"""
// Result & data type check
val testQuery = sql(query)
val dataType = DateType
val expectedResult = "2016-07-29"
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
checkAnswer(testQuery, Row(Date.valueOf(expectedResult)))
})
}

test("WindowTime and TimeWindow expressions with collation") {
// Supported collations
testSuppCollations.foreach(collationName => {
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) {
val query =
s"""SELECT window_time(window)
| FROM (SELECT a, window, count(*) as cnt FROM VALUES
|('A1', '2021-01-01 00:00:00'),
|('A1', '2021-01-01 00:04:30'),
|('A1', '2021-01-01 00:06:00'),
|('A2', '2021-01-01 00:01:00') AS tab(a, b)
|GROUP by a, window(b, '5 minutes') ORDER BY a, window.start);
|""".stripMargin
// Result & data type check
val testQuery = sql(query)
val dataType = TimestampType
val expectedResults =
Seq("2021-01-01 00:04:59.999999",
"2021-01-01 00:09:59.999999",
"2021-01-01 00:04:59.999999")
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
checkAnswer(testQuery, expectedResults.map(ts => Row(Timestamp.valueOf(ts))))
}
})
}

test("SessionWindow expressions with collation") {
// Supported collations
testSuppCollations.foreach(collationName => {
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) {
val query =
s"""SELECT count(*) as cnt
| FROM VALUES
|('A1', '2021-01-01 00:00:00'),
|('A1', '2021-01-01 00:04:30'),
|('A1', '2021-01-01 00:10:00'),
|('A2', '2021-01-01 00:01:00'),
|('A2', '2021-01-01 00:04:30') AS tab(a, b)
|GROUP BY a,
|session_window(b, CASE WHEN a = 'A1' THEN '5 minutes' ELSE '1 minutes' END)
|ORDER BY a, session_window.start;
|""".stripMargin
// Result & data type check
val testQuery = sql(query)
val dataType = LongType
val expectedResults = Seq(2, 1, 1, 1)
assert(testQuery.schema.fields.head.dataType.sameType(dataType))
checkAnswer(testQuery, expectedResults.map(Row(_)))
}
})
}

test("ConvertTimezone expression with collation") {
// Supported collations
testSuppCollations.foreach(collationName => {
Expand Down

0 comments on commit d3455df

Please sign in to comment.