Skip to content

Commit

Permalink
[SPARK-28200][SQL] Decimal overflow handling in ExpressionEncoder
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

- Currently, `ExpressionEncoder` does not handle bigdecimal overflow. Round-tripping overflowing java/scala BigDecimal/BigInteger returns null.
  - The serializer encodes java/scala BigDecimal to sql Decimal, which still retains the underlying data of the former.
  - When writing out to UnsafeRow, `changePrecision` will return false and the row will contain a null value.
https://github.com/apache/spark/blob/24e1e41648de58d3437e008b187b84828830e238/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java#L202-L206
- In [SPARK-23179](apache#20350), an option to throw an exception on decimal overflow was introduced.
- This PR adds the option in `ExpressionEncoder` to throw when detecting overflowing BigDecimal/BigInteger before its corresponding Decimal gets written to Row. This gives a consistent behavior between decimal arithmetic on sql expression (DecimalPrecision), and getting decimal from dataframe (RowEncoder)

Thanks to mgaido91 for the very first PR `SPARK-23179` and follow-up discussion on this change.
Thanks to JoshRosen for working with me on this.

## How was this patch tested?

added unit tests

Closes apache#25016 from mickjermsurawong-stripe/SPARK-28200.

Authored-by: Mick Jermsurawong <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mickjermsurawong-stripe authored and cloud-fan committed Jul 5, 2019
1 parent e299f62 commit 683e270
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object SerializerBuildHelper {

private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow

def createSerializerForBoolean(inputObject: Expression): Expression = {
Invoke(inputObject, "booleanValue", BooleanType)
}
Expand Down Expand Up @@ -99,25 +102,25 @@ object SerializerBuildHelper {
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
StaticInvoke(
CheckOverflow(StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil,
returnNullable = false)
returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow)
}

def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
createSerializerForJavaBigDecimal(inputObject)
}

def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
StaticInvoke(
CheckOverflow(StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil,
returnNullable = false)
returnNullable = false), DecimalType.BigIntDecimal, nullOnOverflow)
}

def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ClosureCleaner
Expand Down Expand Up @@ -379,6 +380,78 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
}

// Scala / Java big decimals ----------------------------------------------------------

// Values at the representable precision/scale bound must round-trip unchanged.
// (20 integral + 18 fractional digits — presumably DecimalType.SYSTEM_DEFAULT's
// precision/scale limit; confirm against DecimalType.)
encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
"scala decimal within precision/scale limit")
encodeDecodeTest(new java.math.BigDecimal(("9" * 20) + "." + "9" * 18),
"java decimal within precision/scale limit")

encodeDecodeTest(-BigDecimal(("9" * 20) + "." + "9" * 18),
"negative scala decimal within precision/scale limit")
encodeDecodeTest(new java.math.BigDecimal(("9" * 20) + "." + "9" * 18).negate,
"negative java decimal within precision/scale limit")

// 21 integral digits exceed the representable integral precision: encoding must
// yield null or throw depending on DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.
testOverflowingBigNumeric(BigDecimal("1" * 21), "scala big decimal")
testOverflowingBigNumeric(new java.math.BigDecimal("1" * 21), "java big decimal")

testOverflowingBigNumeric(-BigDecimal("1" * 21), "negative scala big decimal")
testOverflowingBigNumeric(new java.math.BigDecimal("1" * 21).negate, "negative java big decimal")

// Overflow in the integral part still overflows even when a fractional part is present.
testOverflowingBigNumeric(BigDecimal(("1" * 21) + ".123"),
"scala big decimal with fractional part")
testOverflowingBigNumeric(new java.math.BigDecimal(("1" * 21) + ".123"),
"java big decimal with fractional part")

// A very long fractional part must not mask the integral-part overflow.
testOverflowingBigNumeric(BigDecimal(("1" * 21) + "." + "9999" * 100),
"scala big decimal with long fractional part")
testOverflowingBigNumeric(new java.math.BigDecimal(("1" * 21) + "." + "9999" * 100),
"java big decimal with long fractional part")

// Scala / Java big integers ----------------------------------------------------------

// 38 digits is the largest big integer that fits (per the test names below);
// both signs must round-trip.
encodeDecodeTest(BigInt("9" * 38), "scala big integer within precision limit")
encodeDecodeTest(new BigInteger("9" * 38), "java big integer within precision limit")

encodeDecodeTest(-BigInt("9" * 38),
"negative scala big integer within precision limit")
encodeDecodeTest(new BigInteger("9" * 38).negate(),
"negative java big integer within precision limit")

// 39 digits is one digit past the limit: must be treated as overflow.
testOverflowingBigNumeric(BigInt("1" * 39), "scala big int")
testOverflowingBigNumeric(new BigInteger("1" * 39), "java big integer")

testOverflowingBigNumeric(-BigInt("1" * 39), "negative scala big int")
testOverflowingBigNumeric(new BigInteger("1" * 39).negate, "negative java big integer")

// Grossly oversized values (100 digits) must also be handled, not wrap silently.
testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int")
testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int")

/**
 * Registers two tests for an overflowing numeric value of type `T`: one with
 * nullOnOverflow enabled (round trip yields null) and one with it disabled
 * (encoding throws, with an ArithmeticException as the cause).
 */
private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = {
  for (allowNullOnOverflow <- Seq(true, false)) {
    testAndVerifyNotLeakingReflectionObjects(
      s"overflowing $testName, allowNullOnOverflow=$allowNullOnOverflow") {
      withSQLConf(
        SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> allowNullOnOverflow.toString
      ) {
        // Build the encoder explicitly rather than resolving it implicitly, so the
        // SQLConf value set just above is observed at encoder-construction time.
        val encoder = ExpressionEncoder[T]()
        if (allowNullOnOverflow) {
          // The overflowing value is silently converted to null on the round trip.
          val roundTripped = encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric))
          assert(roundTripped === null)
        } else {
          // Encoding must fail; the encoder wraps the ArithmeticException in a
          // RuntimeException whose message identifies the encoding failure.
          val failure = intercept[RuntimeException] {
            encoder.toRow(bigNumeric)
          }
          assert(failure.getMessage.contains("Error while encoding"))
          assert(failure.getCause.getClass === classOf[ArithmeticException])
        }
      }
    }
  }
}

private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
testName: String): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,32 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(row.toSeq(schema).head == decimal)
}

test("SPARK-23179: RowEncoder should respect nullOnOverflow for decimals") {
  // A 100-digit value cannot fit DecimalType.SYSTEM_DEFAULT; exercise both the
  // Scala and the Java BigDecimal input paths.
  val schema = new StructType().add("decimal", DecimalType.SYSTEM_DEFAULT)
  Seq[Row](
    Row(BigDecimal("9" * 100)),
    Row(new java.math.BigDecimal("9" * 100))
  ).foreach(testDecimalOverflow(schema, _))
}

/**
 * Verifies both overflow behaviors of RowEncoder for an overflowing decimal row:
 * an ArithmeticException (possibly wrapped, depending on codegen vs. interpreted
 * evaluation) when nullOnOverflow is off, and a null result when it is on.
 */
private def testDecimalOverflow(schema: StructType, row: Row): Unit = {
  withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
    val encoder = RowEncoder(schema).resolveAndBind()
    val thrown = intercept[Exception] {
      encoder.toRow(row)
    }
    // ArithmeticException is itself a RuntimeException, so it must be matched first.
    thrown match {
      case ae: ArithmeticException =>
        assert(ae.getMessage.contains("cannot be represented as Decimal"))
      case re: RuntimeException =>
        assert(re.getCause.isInstanceOf[ArithmeticException])
        assert(re.getCause.getMessage.contains("cannot be represented as Decimal"))
    }
  }

  withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
    val encoder = RowEncoder(schema).resolveAndBind()
    // The overflowing value degrades to null rather than failing.
    assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
  }
}

test("RowEncoder should preserve schema nullability") {
val schema = new StructType().add("int", IntegerType, nullable = false)
val encoder = RowEncoder(schema).resolveAndBind()
Expand Down

0 comments on commit 683e270

Please sign in to comment.