Skip to content

Commit

Permalink
[SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

In SPARK-23179, it has been introduced a flag to control the behavior in case of overflow on decimals. The behavior is: returning `null` when `spark.sql.decimalOperations.nullOnOverflow` (default and traditional Spark behavior); throwing an `ArithmeticException` if that conf is false (according to SQL standards, other DBs behavior).

`MakeDecimal` so far had an ambiguous behavior. In case of codegen mode, it returned `null` as the other operators, but in interpreted mode, it was throwing an `IllegalArgumentException`.

The PR aligns `MakeDecimal`'s behavior with the one of other operators as defined in SPARK-23179. So now both modes return `null` or throw `ArithmeticException` according to `spark.sql.decimalOperations.nullOnOverflow`'s value.

Credits for this PR to mickjermsurawong-stripe who pointed out the wrong behavior in apache#20350.

## How was this patch tested?

improved UTs

Closes apache#25010 from mgaido91/SPARK-28201.

Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mgaido91 authored and cloud-fan committed Jul 1, 2019
1 parent 048224c commit bc4a676
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -46,19 +47,38 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
*/
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {

private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow

override def dataType: DataType = DecimalType(precision, scale)
override def nullable: Boolean = true
override def nullable: Boolean = child.nullable || nullOnOverflow
override def toString: String = s"MakeDecimal($child,$precision,$scale)"

protected override def nullSafeEval(input: Any): Any =
Decimal(input.asInstanceOf[Long], precision, scale)
protected override def nullSafeEval(input: Any): Any = {
val longInput = input.asInstanceOf[Long]
val result = new Decimal()
if (nullOnOverflow) {
result.setOrNull(longInput, precision, scale)
} else {
result.set(longInput, precision, scale)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
val setMethod = if (nullOnOverflow) {
"setOrNull"
} else {
"set"
}
val setNull = if (nullable) {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
s"""
${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
${ev.isNull} = ${ev.value} == null;
"""
|${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
|$setNull
|""".stripMargin
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
if (setOrNull(unscaled, precision, scale) == null) {
throw new IllegalArgumentException("Unscaled value too large for precision")
throw new ArithmeticException("Unscaled value too large for precision")
}
this
}
Expand Down Expand Up @@ -111,9 +111,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
require(
decimalVal.precision <= precision,
s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
if (decimalVal.precision > precision) {
throw new ArithmeticException(
s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
}
this.longVal = 0L
this._precision = precision
this._scale = scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}

class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand All @@ -31,8 +32,23 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("MakeDecimal") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
checkEvaluation(overflowExpr, null)
checkEvaluationWithMutableProjection(overflowExpr, null)
evaluateWithoutCodegen(overflowExpr, null)
checkEvaluationWithUnsafeProjection(overflowExpr, null)
}
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr, null))
intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr, null))
intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr, null))
}
}

test("PromotePrecision") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2)
checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
intercept[IllegalArgumentException](Decimal(170L, 2, 1))
intercept[IllegalArgumentException](Decimal(170L, 2, 0))
intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
intercept[ArithmeticException](Decimal(170L, 2, 1))
intercept[ArithmeticException](Decimal(170L, 2, 0))
intercept[ArithmeticException](Decimal(BigDecimal("10.030"), 2, 1))
intercept[ArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1))
intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0))
}

test("creating decimals with negative scale") {
Expand Down

0 comments on commit bc4a676

Please sign in to comment.