[SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow #27627

Closed
wants to merge 15 commits

Conversation

@skambha (Contributor) commented Feb 18, 2020

What changes were proposed in this pull request?

JIRA SPARK-28067: Wrong results are returned for aggregate sum with decimals when whole-stage codegen is enabled.

Repro:
Whole-stage codegen enabled -> wrong results
Whole-stage codegen disabled -> throws an exception: Decimal precision 39 exceeds max precision 38

Issues:

  1. Wrong results are returned, which is a correctness problem.
  2. The behavior is inconsistent between whole-stage codegen enabled and disabled.

Cause:
Sum does not account for the possibility of overflow in the intermediate steps, i.e. the updateExpressions and mergeExpressions.

This PR makes the following changes:

  • Add a check for decimal overflow in the aggregate Sum; if an overflow occurs and spark.sql.ansi.enabled is false, the Sum operation returns null.
  • When spark.sql.ansi.enabled is true, the Sum operation throws an exception if a decimal overflow occurs.
  • This keeps the behavior consistent with the semantics defined by the spark.sql.ansi.enabled property (a sketch of these semantics follows below).
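
As a rough illustration of the intended user-visible semantics, here is a minimal plain-Scala sketch (REPL-pasteable). It is not the catalyst implementation; sumDecimal, ansiEnabled, and the hard-coded decimal(38,18) bound are illustrative assumptions only.

def sumDecimal(values: Seq[BigDecimal], ansiEnabled: Boolean): Option[BigDecimal] = {
  val maxDecimal38_18 = BigDecimal("99999999999999999999.999999999999999999")
  val total = values.sum                      // BigDecimal itself never overflows
  if (total.abs > maxDecimal38_18) {          // but the result type decimal(38,18) cannot hold it
    if (ansiEnabled) {
      throw new ArithmeticException(s"$total cannot be represented as Decimal(38, 18)")
    } else {
      None                                    // non-ANSI mode: the sum becomes null
    }
  } else {
    Some(total)
  }
}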

Before the fix, Scenario 1: WRONG RESULTS

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum"))
df2: org.apache.spark.sql.DataFrame = [sum(decNum): decimal(38,18)]

scala> df2.show(40,false)
+---------------------------------------+                                       
|sum(decNum)                            |
+---------------------------------------+
|20000000000000000000.000000000000000000|
+---------------------------------------+

--
Before the fix, Scenario 2 (spark.sql.ansi.enabled = true): WRONG RESULTS

scala> spark.conf.set("spark.sql.ansi.enabled", "true")

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum"))
df2: org.apache.spark.sql.DataFrame = [sum(decNum): decimal(38,18)]

scala> df2.show(40,false)
+---------------------------------------+
|sum(decNum)                            |
+---------------------------------------+
|20000000000000000000.000000000000000000|
+---------------------------------------+


After the fix, Scenario 1:

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum"))
df2: org.apache.spark.sql.DataFrame = [sum(decNum): decimal(38,18)]

scala>  df2.show(40,false)
+-----------+                                                                   
|sum(decNum)|
+-----------+
|null       |
+-----------+

After the fix, Scenario 2 (spark.sql.ansi.enabled = true):

scala> spark.conf.set("spark.sql.ansi.enabled", "true")

scala> val df = Seq(
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 1),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2),
     |  (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum"))
df2: org.apache.spark.sql.DataFrame = [sum(decNum): decimal(38,18)]

scala>  df2.show(40,false)
20/02/18 10:57:43 ERROR Executor: Exception in task 5.0 in stage 4.0 (TID 30)
java.lang.ArithmeticException: Decimal(expanded,100000000000000000000.000000000000000000,39,18}) cannot be represented as Decimal(38, 18).

Why are the changes needed?

The changes are needed in order to fix the wrong results that are returned for decimal aggregate sum.

Does this PR introduce any user-facing change?

Prior to this change, users would see wrong results for an aggregate sum that involved decimal overflow; now they will see null. If the user sets spark.sql.ansi.enabled to true, they will see an exception instead of incorrect results.

How was this patch tested?

A new test has been added, and the existing sql, catalyst, and hive test suites pass.

@skambha (Contributor, Author) commented Feb 18, 2020

Please see the notes in the JIRA for the two approaches. This is an implementation of the approach 2 fix. I'm marking this as WIP in order to get comments.

if (child.nullable) {
  Seq(
    /* sum = */
    coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum),
skambha (Contributor, Author) commented on the diff above:

The changes in Sum are mostly to check whether overflow has occurred when we perform the additions in updateExpressions and mergeExpressions. The actual addition operations are unchanged. The diff may not show that clearly, so I wanted to note it here.


override lazy val initialValues: Seq[Expression] = Seq(
-  /* sum = */ Literal.create(null, sumDataType)
+  /* sum = */ Literal.create(null, sumDataType),
+  /* overflow = */ Literal.create(false, BooleanType)
skambha (Contributor, Author) commented on the diff above:

We keep track of overflow using the overflow aggregate buffer attribute, so we know whether any of the intermediate add operations in updateExpressions and/or mergeExpressions overflowed. If overflow is true and the spark.sql.ansi.enabled flag is false, we return null for the sum operation in evaluateExpression.
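
To make the idea concrete, here is a plain-Scala model (REPL-pasteable) of the (sum, overflow) buffer; SumBuffer, fits, and the decimal(38,18) bound are illustrative assumptions, not the real catalyst expressions.

case class SumBuffer(sum: Option[BigDecimal], overflow: Boolean)

val initialBuffer = SumBuffer(sum = None, overflow = false)        // mirrors initialValues

def fits(v: BigDecimal): Boolean =
  v.abs <= BigDecimal("99999999999999999999.999999999999999999")   // decimal(38,18) bound

// mirrors updateExpressions: do the addition, but remember that an overflow happened
def update(buf: SumBuffer, input: BigDecimal): SumBuffer =
  if (buf.overflow) buf
  else {
    val next = buf.sum.getOrElse(BigDecimal(0)) + input
    if (fits(next)) buf.copy(sum = Some(next)) else SumBuffer(None, overflow = true)
  }

// mirrors evaluateExpression: honor spark.sql.ansi.enabled
def evaluate(buf: SumBuffer, ansiEnabled: Boolean): Option[BigDecimal] =
  if (buf.overflow) {
    if (ansiEnabled) throw new ArithmeticException("decimal overflow in sum")
    else None                                                      // null on overflow
  } else buf.sum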

@HyukjinKwon (Member):

cc @mgaido91

@cloud-fan (Contributor):

Sum does not account for the possibility of overflow in the intermediate steps, i.e. the updateExpressions and mergeExpressions.

I'm a little confused. These expressions are used in non-whole-stage-codegen as well, so why does only whole-stage codegen have the problem?


override def dataType: DataType = BooleanType

override def nullable: Boolean = true
A reviewer (Member) commented on the diff above:

Can we override nullable with false as doGenCode() does?

skambha (Contributor, Author) replied:

Since the child can be nullable, the input value can be null. Making nullable false in that case will not work, as it may result in an NPE. We could change doGenCode() here to add the null check, but since nullSafeCodeGen in UnaryExpression already takes care of the nullable checks, there seems to be no need to add null checks here.

@kiszk (Member) commented Feb 24, 2020

@skambha Should we review #27629 instead of this PR?

@skambha (Contributor, Author) commented Feb 26, 2020

Thanks @cloud-fan and @kiszk for your comments.

These expressions are used in non-whole-stage-codegen as well, so why does only whole-stage codegen have the problem?

With whole-stage codegen disabled, the exception comes from:

AggregationIterator -> ... -> JoinedRow -> UnsafeRow -> Decimal.set, where it checks whether the value fits within the precision and scale.

Caused by: java.lang.ArithmeticException: Decimal precision 39 exceeds max precision 38
  at org.apache.spark.sql.types.Decimal.set(Decimal.scala:122)
  at org.apache.spark.sql.types.Decimal$.apply(Decimal.scala:574)
  at org.apache.spark.sql.types.Decimal.apply(Decimal.scala)
  at org.apache.spark.sql.catalyst.expressions.UnsafeRow.getDecimal(UnsafeRow.java:385)
  at org.apache.spark.sql.catalyst.expressions.JoinedRow.getDecimal(JoinedRow.scala:95)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateProcessRow$7(AggregationIterator.scala:209)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateProcessRow$7$adapted(AggregationIterator.scala:207)
  at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.processInputs(TungstenAggregationIterator.scala:187)
  at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.<init>(TungstenAggregationIterator.scala:362)
  at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$2(HashAggregateExec.scala:136)
  at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$2$adapted(HashAggregateExec.scala:111)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$2(RDD.scala:889)
  at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndex$2$adapted(RDD.scala:889)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:349)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:313)
  at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
  at org.apache.spark.scheduler.Task.run(Task.scala:127)
  at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:444)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:447)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  at java.lang.Thread.run(Thread.java:748)

———

With whole-stage codegen enabled, here is the code that is generated (current Spark, with no changes from this PR):

a) codegen agg sum - all default (ansi mode false)
b) codegen agg sum - with ansi mode true

In the whole-stage codegen path, you can see that the decimal '+' expressions will at some point produce a value larger than what decimal(38,18) can hold, and that value gets written out as null. This corrupts the end result of the sum and you get wrong results.
The decimal values computed by the '+' expressions are written out via UnsafeRowWriter.write: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java#L184

Here is a snippet highlighting the behavior observed for the use case in this issue: https://github.com/skambha/notes/blob/master/UnsafeRowWriterTestSnippet
The relevant code is here: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java#L202
This is triggered in the whole-stage codegen path.

@skambha (Contributor, Author) commented Feb 26, 2020

@skambha Should we review #27629 instead of this PR?

Thanks for checking. I was hoping to get some feedback on the preferred approach on the JIRA, so in the meantime I have these two PRs open. The other PR throws an exception if there is overflow. This PR honors the ANSI flag/null behavior and follows the approach in SPARK-28224. Once we make a decision, I can close one out.

@cloud-fan (Contributor):

OK to test

@cloud-fan (Contributor):

In the whole-stage codegen path, you can see that the decimal '+' expressions will at some point produce a value larger than what decimal(38,18) can hold, and that value gets written out as null. This corrupts the end result of the sum and you get wrong results.

Still a bit confused. If it gets written out as null, then null + decimal always returns null, so shouldn't the final result be null?

@skambha (Contributor, Author) commented Mar 3, 2020

In the whole-stage codegen path, you can see that the decimal '+' expressions will at some point produce a value larger than what decimal(38,18) can hold, and that value gets written out as null. This corrupts the end result of the sum and you get wrong results.

Still a bit confused. If it gets written out as null, then null + decimal always returns null, so shouldn't the final result be null?

If we look at the aggregate Sum, we have coalesce in updateExpressions and mergeExpressions, so it is not purely a null + decimal expression. For example, in updateExpressions, if the intermediate sum becomes null because of overflow, then in the next iteration coalesce turns the sum into 0 and the sum is updated to 0 + decimal.
https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L76

 override lazy val updateExpressions: Seq[Expression]
...
        /* sum = */
        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
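
For readers following along, here is a plain-Scala model (REPL-pasteable) of the expression above, where None stands for SQL null; the helper name is illustrative only.

// coalesce(coalesce(sum, zero) + input, sum)
def update(sum: Option[BigDecimal], input: Option[BigDecimal]): Option[BigDecimal] = {
  val inner = sum.getOrElse(BigDecimal(0))   // coalesce(sum, zero)
  input.map(inner + _).orElse(sum)           // fall back to the old sum if the input is null
}

// If an earlier overflow had turned `sum` into null, the next non-null input restarts from 0:
update(None, Some(BigDecimal("10000000000000000000")))   // Some(1e19), not None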

Please let me know if this helps.

@cloud-fan (Contributor):

I tried some queries locally but can't reproduce. Is it only a problem with a join?

@cloud-fan (Contributor):

Found a way to reproduce without a join:

scala> val decimalStr = "1" + "0" * 19
decimalStr: String = 10000000000000000000

scala> val df = spark.range(0, 12, 1, 1)
df: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> df.select(expr(s"cast('$decimalStr' as decimal (38, 18)) as d")).agg(sum($"d")).show
// This is correct
+------+
|sum(d)|
+------+
|  null|
+------+

scala> val df = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
df: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> df.select(expr(s"cast('$decimalStr' as decimal (38, 18)) as d")).agg(sum($"d")).show
// This is wrong
+--------------------+
|              sum(d)|
+--------------------+
|10000000000000000...|
+--------------------+

I think the root cause is: sum in the partial aggregate overflows and writes null to the unsafe row. Sum in the final aggregate reads the null from the unsafe row, mistakenly thinks it's caused by empty data, and converts it to 0.

We should create a DecimalSum which uses 2 buffer attributes: sum and isEmpty. Then in the final aggregate we can check the isEmpty flag to know whether the null is caused by overflow or by empty data.
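
A hedged plain-Scala sketch (REPL-pasteable) of that idea; Buf, merge, and finalResult are illustrative names, the overflow check on the merged value is omitted for brevity, and None models SQL null.

case class Buf(sum: Option[BigDecimal], isEmpty: Boolean)

// Final aggregate: merge partial buffers and decide what a null sum means.
def merge(ansiEnabled: Boolean)(acc: Buf, b: Buf): Buf =
  if (b.isEmpty) acc                                    // empty partition: nothing to add
  else b.sum match {
    case Some(v) => Buf(acc.sum.map(_ + v), isEmpty = false)
    case None =>                                        // null but not empty: overflow upstream
      if (ansiEnabled) throw new ArithmeticException("decimal overflow in partial sum")
      else Buf(None, isEmpty = false)                   // keep the null instead of treating it as 0
  }

def finalResult(buffers: Seq[Buf], ansiEnabled: Boolean): Option[BigDecimal] = {
  val merged = buffers.foldLeft(Buf(Some(BigDecimal(0)), isEmpty = true))(merge(ansiEnabled))
  if (merged.isEmpty) None else merged.sum              // no input rows at all: sum is null
}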

@cloud-fan (Contributor) commented Mar 5, 2020

cc @viirya @maropu @dongjoon-hyun as well

@dongjoon-hyun (Member):

Thank you for the example and analysis, @cloud-fan .

@maropu (Member) commented Mar 5, 2020

For integral sum (e.g., int/long), overflow can happen on the partial aggregate side (via Math.addExact). Don't we need to follow that behaviour in decimal sum too, for consistency?

scala> sql("SET spark.sql.ansi.enabled=true")
res39: org.apache.spark.sql.DataFrame = [key: string, value: string]

scala> spark.table("t").printSchema
root
 |-- v: long (nullable = true)

scala> sql("select * from t").show()
+-------------------+
|                  v|
+-------------------+
|9223372036854775807|
|                  1|
+-------------------+


scala> sql("select sum(*) from t").show()
// Throws an exception
java.lang.ArithmeticException: long overflow
	at java.lang.Math.addExact(Math.java:809)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.agg_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.agg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:729)
	at ...

scala> sql("SET spark.sql.ansi.enabled=false")
scala> sql("select sum(*) from t").show()
// Wrong result
+--------------------+
|              sum(v)|
+--------------------+
|-9223372036854775808|
+--------------------+

@cloud-fan (Contributor):

Ideally, under ANSI mode, decimal sum should also fail if it overflows. It's hard to fail on the partial aggregate side, as we don't have a chance to check for overflow before shuffling the aggregate buffer row. We can fail on the final aggregate side: if the value is null and isEmpty is false, fail with the overflow exception.

@viirya (Member) commented Mar 5, 2020

Based on the previous explanation, it looks like the overflow happens during the partial aggregate (updateExpressions and mergeExpressions)? At the final aggregate, is a check for null useful for the reported case here?

@cloud-fan (Contributor):

looks like the overflow happens during the partial aggregate (updateExpressions and mergeExpressions)

The partial aggregate runs updateExpressions and produces a buffer; the final aggregate runs mergeExpressions to merge buffers. The problem here is that mergeExpressions treats null as 0.

@viirya (Member) commented Mar 5, 2020

For example, in updateExpressions, if the intermediate sum becomes null because of overflow, then in the next iteration coalesce turns the sum into 0 and the sum is updated to 0 + decimal.

Sounds like it is not just a problem of mergeExpressions, then?

@cloud-fan (Contributor):

Note: the semantics of most aggregate functions is to skip nulls. So sum will skip null inputs, and overflow doesn't return null (also true for decimal), so updateExpressions is fine.

@skambha (Contributor, Author) commented Mar 5, 2020

In the repro (that has the join and then the agg), we also see the null being written out via the UnsafeRowWriter, as mentioned in my earlier comment with codegen details: #27627 (comment)

Codegen and plan details from #27627 (comment):
https://github.com/skambha/notes/blob/master/spark28067_ansifalse_wholestagetrue.txt

scala> df2.queryExecution.debug.codegen
Found 2 WholeStageCodegen subtrees.
== Subtree 1 / 2 (maxMethodCodeSize:173; maxConstantPoolSize:145(0.22% used); numInnerClasses:0) ==
*(2) HashAggregate(keys=[], functions=[sum(decNum#14)], output=[sum(decNum)#23])
+- Exchange SinglePartition, true, [id=#77]
   +- *(1) HashAggregate(keys=[], functions=[partial_sum(decNum#14)], output=[sum#29])
      +- *(1) Project [decNum#14]
         +- *(1) BroadcastHashJoin [intNum#8], [intNum#15], Inner, BuildLeft
            :- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))), [id=#65]
            :  +- LocalTableScan [intNum#8]
            +- *(1) LocalTableScan [decNum#14, intNum#15]

The relevant lines of code are the following:

  • Addition of two decimal values (the expressions coming from sum) that produces a value which cannot be contained.
  • Writing the overflowed decimal via UnsafeRowWriter, which silently writes null.

Codegen id 1:

/* 108 */ agg_value_3 = agg_value_4.$plus(agg_expr_0_0);
/* 113 */ agg_value_2 = agg_value_3;
/* 126 */ agg_mutableStateArray_0[0] = agg_value_2;
/* 147 */ bhj_mutableStateArray_0[3].write(0, agg_mutableStateArray_0[0], 38, 18);

https://github.com/skambha/notes/blob/master/spark28067_ansifalse_wholestagetrue.txt#L336
https://github.com/skambha/notes/blob/master/spark28067_ansifalse_wholestagetrue.txt#L354
https://github.com/skambha/notes/blob/master/spark28067_ansifalse_wholestagetrue.txt#L375

Codegen id 2:

/* 080 */ agg_value_5 = agg_value_6.$plus(agg_expr_0_0);
/* 091 */ agg_value_4 = agg_mutableStateArray_0[0];
/* 098 */ agg_mutableStateArray_0[0] = agg_value_4;
/* 128 */ agg_mutableStateArray_1[0].write(0, agg_value_1, 38, 18);

https://github.com/skambha/notes/blob/master/spark28067_ansifalse_wholestagetrue.txt#L164
https://github.com/skambha/notes/blob/master/spark28067_ansifalse_wholestagetrue.txt#L212

  • For both the partial and final aggregate stages, the same issue can occur.

@cloud-fan, given the above pieces of code, why would the problem not exist for the partial aggregate (updateExpressions) case?

@cloud-fan (Contributor):

Addition of two decimal values (the expressions coming from sum) that produces a value which cannot be contained.

You can try Decimal.+(Decimal) locally; it does return a value that is not null. We can't hold an overflowed decimal value in an unsafe row, but a Decimal object can be temporarily overflowed.

In my repro, spark.range(0, 12, 1, 1) works fine and spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) gives a wrong result. I looked at the code again, and whole-stage codegen also stores the partial aggregate result in an unsafe row. Can someone investigate further and see why spark.range(0, 12, 1, 1) works?
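
A quick way to see this in a spark-shell, assuming Spark's internal Decimal API (a hedged sketch, not code from this PR):

import org.apache.spark.sql.types.Decimal

val a = Decimal("99999999999999999999.999999999999999999")   // max value of decimal(38,18)
val b = Decimal("10000000000000000000")                       // 1e19

val tmp = a + b               // no exception: the in-memory Decimal simply holds a wider value
tmp.changePrecision(38, 18)   // false: the value cannot be represented as decimal(38,18),
                              // which is why the row writer ends up storing null instead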

@skambha (Contributor, Author) commented Mar 6, 2020

Yes, I agree. For example, agg_value_3 = agg_value_4.$plus(agg_expr_0_0); can overflow and hold the value, but not when it is written out to the UnsafeRow.

Can someone investigate further and see why spark.range(0, 12, 1, 1) works?

Yes. I will look into this. If someone else has suggestions please do share.

@skambha (Contributor, Author) commented Mar 10, 2020

I looked into why spark.range(0,12,1,1) works while the other scenario mentioned in #27627 (comment) returns wrong results.

Case 1, which used spark.range(0,12,1,1): this returns null.
Please see this file for details: https://github.com/skambha/notes/blob/master/case1_works_notes.txt

There is only one partition and one whole-stage codegen subtree, so there is no intermediate write via UnsafeRowWriter. All the values are added up into a value that overflows, then the overflow check kicks in, sees the overflow, and turns the value into null. The null then gets written out.

  1. There is 1 partition.
  2. There is 1 WholeStageCodegen subtree.
  3. The updateExpressions codegen therefore does the add (+) for all 12 decimal values without writing out to UnsafeRow via UnsafeRowWriter.
     This amounts to 120000000000000000000.000000000000000000, which is an overflowed value.
  4. This then goes through the CheckOverflow code, which checks the decimal value against the precision; since it overflows, the call evaluates to null.

/* 239 */         agg_value_1 = agg_mutableStateArray_0[0].toPrecision(
/* 240 */           38, 18, Decimal.ROUND_HALF_UP(), true);

  5. The null is then written out via the UnsafeRowWriter.write call below.

/* 250 */       if (agg_isNull_1) {
/* 251 */         range_mutableStateArray_0[3].write(0, (Decimal) null, 38, 18);

This explains why this scenario works fine.

Case 2, which used spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) and got wrong results.
Please see the file for details:
https://github.com/skambha/notes/blob/master/case2_wrongresults.txt

  1. WholeStageCodegen has 4 subtrees.
  2. NumPartitions is 2.
  3. There are 11 rows in one partition and 1 row in the other partition.
  4. For the partition with 11 rows, the rows add up to 110000000000000000000.000000000000000000, which does not fit in decimal(38,18), so UnsafeRowWriter writes it out as null.
  5. For the partition with 1 row, the sum is 10000000000000000000.000000000000000000, which fits.
  6. Next, the two buffers, a) null and b) 10000000000000000000, are merged using the coalesce and + expressions in mergeExpressions, and the result is then checked for overflow. This value fits, so the wrong result is returned.

(Side note: if you split the range into 2 and 11 elements, you can see the result is 20000000000000000000.000000000000000000.)


We already saw that when an overflowed decimal value is written out to UnsafeRow, a null value is written.

So, whether the wrong result reproduces is basically determined by:

  • the values in each partition: if the sum of the values in a partition overflows and is written out via UnsafeRowWriter, that partition's sum ends up as null, and
  • whether, when the per-partition sums are then summed in the merge phase, the resulting value overflows or not.
    • If it overflows, the overflow check (coming from evaluateExpression) turns the result into null, which is then written out.
    • If it does not overflow, the resulting value (which could be wrong) is written out.

A small plain-Scala model of this follows below.
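
The model is REPL-pasteable and assumes the pre-fix behaviour where an overflowed partition sum is written out as null (modeled as None); all names are illustrative, not Spark code.

val oneE19 = BigDecimal("10000000000000000000")
val max38_18 = BigDecimal("99999999999999999999.999999999999999999")

def partialSum(rows: Seq[BigDecimal]): Option[BigDecimal] = {
  val s = rows.sum
  if (s.abs > max38_18) None else Some(s)        // overflow silently becomes null
}

// pre-fix merge: coalesce(coalesce(sum, zero) + incoming, sum) treats null as 0
def merge(acc: Option[BigDecimal], incoming: Option[BigDecimal]): Option[BigDecimal] =
  incoming.map(acc.getOrElse(BigDecimal(0)) + _).orElse(acc)

val buffers = Seq(
  partialSum(Seq.fill(11)(oneE19)),   // 1.1e20 overflows -> None
  partialSum(Seq.fill(1)(oneE19))     // 1e19 fits         -> Some(1e19)
)
buffers.foldLeft(Option.empty[BigDecimal])(merge)   // Some(1e19): the overflowed partition vanished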

@cloud-fan (Contributor):

LGTM except a few minor comments. This changes the aggregate buffer of the sum function, which means the streaming state store format is also changed. We need to mention it in the migration guide and suggest users restart the stream if their query contains sum of decimals.

cc @marmbrus @zsxwing @HeartSaVioR @xuanyuanking @dongjoon-hyun. Usually we need to backport correctness fixes, but this breaks the streaming checkpoint, and I don't know whether that affects the backport policy.

@HeartSaVioR (Contributor) commented Jun 1, 2020

This PR "selectively" changes the aggregation buffer (not entirely). I don't know how common decimal input to sum is, so it's hard to say, but we make streaming queries that contain a stream-stream outer join fail if they start from a previous version of the state, whereas we allow stream-stream inner join to run with the previous version of the state.

https://spark.apache.org/docs/3.0.0-preview2/ss-migration-guide.html

Spark 3.0 fixes the correctness issue on Stream-stream outer join, which changes the schema of state. (SPARK-26154 for more details) Spark 3.0 will fail the query if you start your query from checkpoint constructed from Spark 2.x which uses stream-stream outer join. Please discard the checkpoint and replay previous inputs to recalculate outputs.

SPARK-26154 was put only into Spark 3.0.0 because of backward incompatibility. IMHO this sounds like a similar case: it should land in Spark 3.0.0, and I'm not sure we want to make a breaking change in Spark 2.4.x.

Btw, that was only possible because we added "versioning" of the state, but we have been applying the versioning to the operator rather than to a "specific" aggregate function, so that doesn't cover this case yet.

If we feel it is beneficial to build a way to apply versioning to a specific aggregate function (at least built-in ones) and someone has a straightforward way to do this, that would be ideal.

Otherwise, I'd rather say we should make it fail, but not with an "arbitrary" error message. SPARK-27237 (#24173) will make it possible to show the schema incompatibility and why it's considered incompatible.

I'd also propose introducing a state data source for batch queries via SPARK-28190, so that end users can easily read, and even rewrite, the state, which would eventually let them migrate from the old state format to the new one. (It's available as a separate project for Spark 2.4.x, but given there were some relevant requests on the user mailing list, e.g. pre-loading initial state and schema evolution, it's easier to work on within the Spark project.)

@SparkQA commented Jun 1, 2020

Test build #123389 has finished for PR 27627 at commit 7795888.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@HeartSaVioR (Contributor):

retest this, please

@xuanyuanking (Member):

Thanks @cloud-fan for the reminder.

  • Agree with Jungtaek about only putting this into 3.0 and leaving 2.4 and earlier versions unchanged, since sum(decimal) is a common use case and this PR aims to fix the overflow scenario.

  • Only adding a migration guide might not be enough. We need detection logic for the state store and to fail fast here. I'll submit a PR this week. Recently I have been addressing some streaming issues caused by underlying state store format changes. It's quite dangerous now because such changes may cause random exceptions, or even wrong answers, in Structured Streaming.

  • I'm wondering whether the bug fix can be achieved without changing the format for all sum(decimal) cases, so that it won't break most checkpoints.

@HeartSaVioR (Contributor):

@xuanyuanking

We need detection logic for the state store and to fail fast here. I'll submit a PR this week.

Just to avoid redundant effort, have you looked into #24173? If your approach is different from #24173, what approach will you be proposing?

@cloud-fan (Contributor):

How about we merge it to master only first, and wait for the schema incompatibility check to be done? This is a long-standing issue and we don't need to block 3.0. We can still backport to branch-3.0 with more discussion.

@HeartSaVioR (Contributor):

Personally I'm OK with skipping this for Spark 3.0.0, preferably mentioning it as a "known issue". It would be ideal if we could maintain a page of known "correctness"/"data-loss" issues and the fixed version(s) once an issue is resolved in a specific version.

@cloud-fan cloud-fan changed the title [WIP][SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow [SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow Jun 2, 2020
@cloud-fan (Contributor):

retest this please

@SparkQA commented Jun 2, 2020

Test build #123398 has finished for PR 27627 at commit 7795888.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@xuanyuanking (Member) commented Jun 2, 2020

How about we merge it to master only first, and wait for the schema incompatibility check to be done?

Agree.

Just to avoid redundant effort, have you looked into #24173? If your approach is different from #24173, what approach will you be proposing?

@HeartSaVioR Thanks for the reminder. I also looked into #24173 before. My approach checks the underlying unsafe row format instead of adding a new schema file in the checkpoint. That is driven by the requirement of detecting the format change during migration, where the user has no chance to create a schema file.

But I think our approaches can complement each other. Let's discuss it in my newly created PR; I'll submit it later today.

@SparkQA commented Jun 2, 2020

Test build #123414 has finished for PR 27627 at commit 7795888.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@xuanyuanking (Member):

retest this please

@cloud-fan (Contributor):

tests already pass: #27627 (comment)

I'm merging it to master first, thanks!

@cloud-fan cloud-fan closed this in 4161c62 Jun 2, 2020
@SparkQA commented Jun 2, 2020

Test build #123431 has finished for PR 27627 at commit 7795888.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@skambha (Contributor, Author) commented Jun 2, 2020

Thanks @cloud-fan and all for merging into master.

Thanks for the discussion related to backward compatibility for streaming use cases.
I wanted to clarify my understanding of some of the recent discussion on it.

xuanyuanking said "Agree with Jungtaek about only putting this into 3.0 and leaving 2.4 and earlier versions unchanged." #27627 (comment)

Initially, it sounded like both @xuanyuanking and @HeartSaVioR were OK with it going into v3.0 and not into v2.4.

But later we decided not to put it into v3.0.

Could you point out the PR/JIRA that we are waiting on before this can go into the v3.x branch?

I agree that if there is a compatibility issue, it is better for the change to go into a major release than a minor one. Just trying to understand: what is different about this case compared with SPARK-26154, which went into v3.0 and was pointed out here: #27627 (comment)?

@HeartSaVioR (Contributor) commented Jun 3, 2020

Actually, Spark 3.0.0 is the better place to land this if we are only concerned about backward compatibility, but even for a major version update we don't want to scare end users.

SPARK-26154 introduced "versioning" of the state of stream-stream join, so that Spark 3.0.0 can recognize the "old" state and fail the query with a "proper" error message. There's no such thing for this patch; that's why I asked about "versioning" of state for a streaming aggregation "function", but I'm not sure it's the preferred approach, and even if we agree on it I'm not sure we have enough time to deal with it in Spark 3.0.0. (It is already quite delayed.)

My personal feeling is that we should bring in the essential functionality (say, schema information of state, #24173) ASAP, so that in the future we can at least guide end users in such cases, e.g. "if you didn't touch your query but encounter a schema-incompatibility error on state, please check the migration guide for your Spark version to see whether there is a backward-incompatible change".

Unfortunately, even if we adopt #24173 into Spark 3.0.0 (and I'm not sure that will happen), it doesn't apply to migration from Spark 2.x to 3.0.0, as existing state from Spark 2.x won't have a schema. We could craft a tool to create the schema file for states of a Spark 2.x Structured Streaming query so that end users can adopt it before migrating to Spark 3.0.0, but that's still a rough idea and I'd need some support from the community to move on.

#28707 would at least help detect the problem for this issue (as the number of fields will not match), so #28707 might unblock including this patch in branch-3.0, but the error message would be a bit unfriendly because we won't have detailed information about the schema of the state.

@skambha (Contributor, Author) commented Jun 3, 2020

Actually, Spark 3.0.0 is the better place to land this if we are only concerned about backward compatibility,

I think so too, since this is a correctness issue.

To summarize my understanding from the discussion with @cloud-fan on #28707: we are waiting on that PR, which is probably going into v3.1. It will help detect compatibility issues and throw a better error message, which is a good thing.

That said, if we wait for that and both of these go into v3.1, say, I wonder whether it becomes more of an unexpected breaking change for streaming use cases when users go from v3.0 to v3.1. @cloud-fan, wdyt? Thanks.

HyukjinKwon pushed a commit that referenced this pull request Jul 9, 2020
### What changes were proposed in this pull request?

This is a followup of #27627 to fix the remaining issues. There are 2 issues fixed in this PR:
1. `UnsafeRow.setDecimal` can set an overflowed decimal and causes an error when reading it. The expected behavior is to return null.
2. The update/merge expression for decimal type in `Sum` is wrong. We shouldn't turn the `sum` value back to 0 after it becomes null due to overflow. This issue was hidden because:
2.1 for hash aggregate, the buffer is unsafe row. Due to the first bug, we fail when overflow happens, so there is no chance to mistakenly turn null back to 0.
2.2 for sort-based aggregate, the buffer is generic row. The decimal can overflow (the Decimal class has unlimited precision) and we don't have the null problem.

If we only fix the first bug, then the second bug is exposed and test fails. If we only fix the second bug, there is no way to test it. This PR fixes these 2 bugs together.

### Why are the changes needed?

Fix issues during decimal sum when overflow happens

### Does this PR introduce _any_ user-facing change?

Yes. Now decimal sum can return null correctly for overflow under non-ansi mode.

### How was this patch tested?

new test and updated test

Closes #29026 from cloud-fan/decimal.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
cloud-fan pushed a commit that referenced this pull request Aug 13, 2020
…erflow of sum aggregation

### What changes were proposed in this pull request?

This is a followup of #29125
In branch 3.0:
1. for hash aggregation, before #29125 there will be a runtime exception on decimal overflow of sum aggregation; after #29125, there could be a wrong result.
2. for sort aggregation, with/without #29125, there could be a wrong result on decimal overflow.

While in the master branch (the future 3.1 release), the problem doesn't exist, since in #27627 there is a flag marking whether overflow happened in the aggregation buffer. However, the aggregation buffer is written into streaming checkpoints. Thus, we can't change the aggregation buffer to resolve the issue.

As there is no easy solution for returning null/throwing an exception based on `spark.sql.ansi.enabled` on overflow in branch 3.0, we have to make a choice here: always throw an exception on decimal value overflow of sum aggregation.
### Why are the changes needed?

Avoid returning wrong result in decimal value sum aggregation.

### Does this PR introduce _any_ user-facing change?

Yes, there is always exception on decimal value overflow of sum aggregation, instead of a possible wrong result.

### How was this patch tested?

Unit test case

Closes #29404 from gengliangwang/fixSum.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
jlfsdtc added a commit to Kyligence/spark that referenced this pull request Jul 23, 2021
* KE-24858
[SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow

* [SPARK-28067][SPARK-32018] Fix decimal overflow issues


* KE-24858
fix error: java.lang.IllegalArgumentException: Can not interpolate java.lang.Boolean into code block.

* KE-24858
fix ci error

* KE-24858 update pom version

Co-authored-by: Sunitha Kambhampati <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: longfei.jiang <[email protected]>