# [SPARK-34837][SQL] Support ANSI SQL intervals by the aggregate function `avg`

### What changes were proposed in this pull request?
Extend the `Average` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614.

Note: the expressions can throw an overflow exception regardless of the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave as in ANSI mode for interval inputs.
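For illustration, a hypothetical `spark-shell` session against a build with this patch (the data mirrors the overflow test added below; the session setup and exact exception wrapping are assumptions):

```scala
import java.time.Period
import org.apache.spark.sql.functions.avg
import spark.implicits._

// Disabling ANSI mode does not relax interval arithmetic:
spark.conf.set("spark.sql.ansi.enabled", false)

// Int.MaxValue months plus 10 months overflows the Int month counter.
Seq(Period.ofMonths(Int.MaxValue), Period.ofMonths(10))
  .toDF("i")
  .select(avg($"i"))
  .collect()
// => org.apache.spark.SparkException caused by
//    java.lang.ArithmeticException: integer overflow
```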

### Why are the changes needed?
Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Average` to support `DayTimeIntervalType` and `YearMonthIntervalType`.

### Does this PR introduce _any_ user-facing change?
No. It should not, since the new interval types have not been released yet.

### How was this patch tested?
Jenkins tests, plus the new unit tests in `ExpressionTypeCheckingSuite` and `DataFrameAggregateSuite` shown below.

Closes #32229 from beliefer/SPARK-34837.

Authored-by: gengjiaan <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
beliefer authored and MaxGekk committed Apr 19, 2021
1 parent 70b606f commit 8dc455b
Showing 5 changed files with 60 additions and 9 deletions.
`Average` (Catalyst aggregate expression):

```diff
@@ -40,10 +40,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
 
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
 
   override def nullable: Boolean = true
 
@@ -53,11 +54,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
 
   private lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
 
@@ -82,6 +87,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal(
         Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+    case _: YearMonthIntervalType => DivideYMInterval(sum, count)
+    case _: DayTimeIntervalType => DivideDTInterval(sum, count)
     case _ =>
       Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
   }
```
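The final division step is what the interval branches change: the accumulated sum stays an interval, the non-null row count stays a long, and `DivideYMInterval`/`DivideDTInterval` perform the division. A rough plain-Scala model of the semantics follows; this is not the actual Catalyst code, and the microsecond representation, the use of exact arithmetic, and the truncating division are assumptions inferred from the overflow errors and results asserted in the tests below:

```scala
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit

// Year-month intervals are a count of months; exact addition raises
// "integer overflow" when the running total leaves the Int range.
def avgYearMonth(ps: Seq[Period]): Period = {
  val totalMonths = ps.foldLeft(0) { (acc, p) =>
    Math.addExact(acc, Math.toIntExact(p.toTotalMonths))
  }
  // Integer division is a simplification; DivideYMInterval defines the
  // exact rounding behavior.
  Period.ofMonths(totalMonths / ps.size)
}

// Day-time intervals are a count of microseconds held in a Long; exact
// addition raises "long overflow" near Duration.ofDays(106751991).
def avgDayTime(ds: Seq[Duration]): Duration = {
  def micros(d: Duration): Long =
    Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), (d.getNano / 1000).toLong)
  val totalMicros = ds.foldLeft(0L)((acc, d) => Math.addExact(acc, micros(d)))
  Duration.of(totalMicros / ds.size, ChronoUnit.MICROS)
}

// Matches the test below: (10 + 1 - 3 + 21) months over 4 non-null rows.
avgYearMonth(Seq(10, 1, -3, 21).map(Period.ofMonths))  // Period.ofMonths(7)
```

Nulls never reach the sum or the count, which is why class `2` in the test suite below averages to its single non-null row.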
`Sum` (Catalyst aggregate expression):

```diff
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -48,12 +49,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
-  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
-    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function sum requires numeric or interval types, not ${other.catalogString}")
-  }
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
 
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
```
`TypeUtils` (shared type-check helper, extracted from `Sum`):

```diff
@@ -61,6 +61,14 @@ object TypeUtils {
     }
   }
 
+  def checkForAnsiIntervalOrNumericType(
+      dt: DataType, funcName: String): TypeCheckResult = dt match {
+    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
+    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
+    case other => TypeCheckResult.TypeCheckFailure(
+      s"function $funcName requires numeric or interval types, not ${other.catalogString}")
+  }
+
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
     if (exactNumericRequired) {
       t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
```
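The helper gives `sum` and `avg` a single shared error message. A hypothetical REPL probe of its behavior (imports assumed; the messages come straight from the code above):

```scala
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

// Numeric, ANSI interval, and null inputs all pass the check.
TypeUtils.checkForAnsiIntervalOrNumericType(IntegerType, "average")
// => TypeCheckSuccess
TypeUtils.checkForAnsiIntervalOrNumericType(YearMonthIntervalType, "average")
// => TypeCheckSuccess

// Anything else fails, with the function name spliced into the message.
TypeUtils.checkForAnsiIntervalOrNumericType(BooleanType, "average")
// => TypeCheckFailure("function average requires numeric or interval types, not boolean")
```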
`ExpressionTypeCheckingSuite`:

```diff
@@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
     assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
-    assertError(Average(Symbol("booleanField")), "function average requires numeric type")
+    assertError(Average(Symbol("booleanField")),
+      "function average requires numeric or interval types")
   }
 
   test("check types for others") {
```
`DataFrameAggregateSuite`:

```diff
@@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
     }
     assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
   }
+
+  test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
+    val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
+      (2, Period.ofMonths(1), Duration.ofDays(1)),
+      (2, null, null),
+      (3, Period.ofMonths(-3), Duration.ofDays(-6)),
+      (3, Period.ofMonths(21), Duration.ofDays(-5)))
+      .toDF("class", "year-month", "day-time")
+
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day-time")
+
+    val avgDF = df.select(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
+    assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
+    checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
+      Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
+      Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) :: Nil)
+    assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+      StructField("avg(year-month)", YearMonthIntervalType),
+      StructField("avg(day-time)", DayTimeIntervalType))))
+
+    val error = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"year-month")), Nil)
+    }
+    assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+
+    val error2 = intercept[SparkException] {
+      checkAnswer(df2.select(avg($"day-time")), Nil)
+    }
+    assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+  }
 }
 
 case class B(c: Option[Double])
```
