Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow #27627

Closed
wants to merge 15 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,74 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

private lazy val sum = AttributeReference("sum", sumDataType)()

private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()

private lazy val zero = Literal.default(sumDataType)

override lazy val aggBufferAttributes = sum :: Nil
override lazy val aggBufferAttributes = resultType match {
case _: DecimalType => sum :: isEmpty :: Nil
case _ => sum :: Nil
}

override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
)
override lazy val initialValues: Seq[Expression] = resultType match {
case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
case _ => Seq(Literal(null, resultType))
}

override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
Seq(
/* sum = */
coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
)
val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, isEmpty && child.isNull)
case _ => Seq(updateSumExpr)
}
} else {
Seq(
/* sum = */
coalesce(sum, zero) + child.cast(sumDataType)
)
val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, Literal(false, BooleanType))
case _ => Seq(updateSumExpr)
}
}
}

/**
* For decimal type:
* If isEmpty is false and if sum is null, then it means we have an overflow.
*
* update of the sum is as follows:
* Check if either portion of the left.sum or right.sum has overflowed
skambha marked this conversation as resolved.
Show resolved Hide resolved
* If it has, then the sum value will remain null.
* If it did not have overflow, then add the sum.left and sum.right and check for overflow.
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
skambha marked this conversation as resolved.
Show resolved Hide resolved
*
* isEmpty: Set to false if either one of the left or right is set to false. This
* means we have seen atleast a row that was not null.
skambha marked this conversation as resolved.
Show resolved Hide resolved
*/
override lazy val mergeExpressions: Seq[Expression] = {
Seq(
/* sum = */
coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
)
val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
resultType match {
case _: DecimalType =>
val inputOverflow = !isEmpty.right && sum.right.isNull
val bufferOverflow = !isEmpty.left && sum.left.isNull
Seq(
If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
isEmpty.left && isEmpty.right)
case _ => Seq(mergeSumExpr)
}
}

/**
* If the isEmpty is true, then it means either there are no rows, or all the rows were
* null, so the result will be null.
* If the isEmpty is false, then if sum is null that means an overflow has happened.
* So now, if ansi is enabled, then throw exception, if not then return null.
* If sum is not null, then return the sum.
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
*/
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled)
cloud-fan marked this conversation as resolved.
Show resolved Hide resolved
case d: DecimalType =>
If(isEmpty, Literal.create(null, sumDataType),
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
case _ => sum
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -144,3 +145,54 @@ case class CheckOverflow(

override def sql: String = child.sql
}

// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
case class CheckOverflowInSum(
child: Expression,
dataType: DecimalType,
nullOnOverflow: Boolean) extends UnaryExpression {

override def nullable: Boolean = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we override nullable with false as doGenCode() does?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the child can be nullable, the input value can be null. Making nullable to false in that case will not work, as it may result in npe. We can change the doGenCode() here to make the check for null for that, but since the nullSafeCodeGen in UnaryExpression already takes care of the if nullable checks, it seems there is no need to add if null checks here.


override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.")
} else {
value.asInstanceOf[Decimal].toPrecision(
dataType.precision,
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val nullHandling = if (nullOnOverflow) {
""
} else {
s"""
|throw new ArithmeticException("Overflow in sum of decimals.");
|""".stripMargin
}
val code = code"""
|${childGen.code}
|boolean ${ev.isNull} = ${childGen.isNull};
|Decimal ${ev.value} = null;
|if (${childGen.isNull}) {
| $nullHandling
|} else {
| ${ev.value} = ${childGen.value}.toPrecision(
| ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
| ${ev.isNull} = ${ev.value} == null;
|}
|""".stripMargin

ev.copy(code = code)
}

override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"

override def sql: String = child.sql
}
112 changes: 105 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest
structDf.select(xxhash64($"a", $"record.*")))
}

private def assertDecimalSumOverflow(
df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
if (!ansiEnabled) {
try {
checkAnswer(df, expectedAnswer)
} catch {
case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
// This is an existing bug that we can write overflowed decimal to UnsafeRow but fail
// to read it.
assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
}
} else {
val e = intercept[SparkException] {
df.collect
}
assert(e.getCause.isInstanceOf[ArithmeticException])
assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
e.getCause.getMessage.contains("Overflow in sum of decimals") ||
e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
}
}

test("SPARK-28224: Aggregate sum big decimal overflow") {
val largeDecimals = spark.sparkContext.parallelize(
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
Expand All @@ -200,14 +222,90 @@ class DataFrameSuite extends QueryTest
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
if (!ansiEnabled) {
checkAnswer(structDf, Row(null))
} else {
val e = intercept[SparkException] {
structDf.collect
assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
}
}
}

test("SPARK-28067: sum of null decimal values") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq("true", "false").foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
checkAnswer(df.agg(sum($"d")), Row(null))
}
}
}
}
}

test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val df0 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df1 = Seq(
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df = df0.union(df1)
val df2 = df.withColumnRenamed("decNum", "decNum2").
join(df, "intNum").agg(sum("decNum"))
skambha marked this conversation as resolved.
Show resolved Hide resolved

val expectedAnswer = Row(null)
assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)

val decStr = "1" + "0" * 19
val d1 = spark.range(0, 12, 1, 1)
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)

val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)

val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)

val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))

val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
toDF("d")
assertDecimalSumOverflow(
nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)

val df3 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("50000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")

val df4 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")

val df5 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")

val df6 = df3.union(df4).union(df5)
val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
filter("intNum == 1")
assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
}
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
}
}
}
Expand Down