Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23179][SQL] Support option to throw exception if overflow occurs during Decimal arithmetic #20350

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ object DecimalPrecision extends TypeCoercionRule {
PromotePrecision(Cast(e, dataType))
}

private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// fix decimal precision for expressions
case q => q.transformExpressionsUp(
Expand All @@ -105,7 +107,7 @@ object DecimalPrecision extends TypeCoercionRule {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
resultType)
resultType, nullOnOverflow)

case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultScale = max(s1, s2)
Expand All @@ -116,7 +118,7 @@ object DecimalPrecision extends TypeCoercionRule {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)),
resultType)
resultType, nullOnOverflow)

case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -126,7 +128,7 @@ object DecimalPrecision extends TypeCoercionRule {
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -148,7 +150,7 @@ object DecimalPrecision extends TypeCoercionRule {
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -159,7 +161,7 @@ object DecimalPrecision extends TypeCoercionRule {
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
Expand All @@ -170,7 +172,7 @@ object DecimalPrecision extends TypeCoercionRule {
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
resultType)
resultType, nullOnOverflow)

case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
collect(left, negate) ++ collect(right, !negate)
case UnaryMinus(child) =>
collect(child, !negate)
case CheckOverflow(child, _) =>
case CheckOverflow(child, _, _) =>
collect(child, negate)
case PromotePrecision(child) =>
collect(child, negate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ object RowEncoder {
d,
"fromDecimal",
inputObject :: Nil,
returnNullable = false), d)
returnNullable = false), d, SQLConf.get.decimalOperationsNullOnOverflow)

case StringType => createSerializerForString(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,34 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {

/**
* Rounds the decimal to given scale and check whether the decimal can fit in provided precision
* or not, returns null if not.
* or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
* `ArithmeticException` is thrown.
*/
case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
case class CheckOverflow(
child: Expression,
dataType: DecimalType,
nullOnOverflow: Boolean) extends UnaryExpression {

override def nullable: Boolean = true

override def nullSafeEval(input: Any): Any =
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
input.asInstanceOf[Decimal].toPrecision(
dataType.precision,
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
val tmp = ctx.freshName("tmp")
s"""
| Decimal $tmp = $eval.clone();
| if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
| ${ev.value} = $tmp;
| } else {
| ${ev.isNull} = true;
| }
|${ev.value} = $eval.toPrecision(
| ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
|${ev.isNull} = ${ev.value} == null;
""".stripMargin
})
}

override def toString: String = s"CheckOverflow($child, $dataType)"
override def toString: String = s"CheckOverflow($child, $dataType, $nullOnOverflow)"

override def sql: String = child.sql
}
Original file line number Diff line number Diff line change
Expand Up @@ -1138,8 +1138,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
val evaluationCode = dataType match {
case DecimalType.Fixed(_, s) =>
s"""
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
${ev.isNull} = ${ev.value} == null;"""
|${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
| Decimal.$modeStr(), true);
|${ev.isNull} = ${ev.value} == null;
""".stripMargin
case ByteType =>
if (_scale < 0) {
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val DECIMAL_OPERATIONS_NULL_ON_OVERFLOW =
buildConf("spark.sql.decimalOperations.nullOnOverflow")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overflow can happen with non-decimal operations, do we need a new config?

cc @JoshRosen

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking a look at this @cloud-fan !

Yes, that case (non-decimal) is handled in #21599. I'd say that, in the non-decimal case, the situation is pretty different. Indeed, overflow in decimal operation is handled by Spark now, converting overflow operations to null; while overflow in operation on non-decimal isn't handled at all currently.

In non-decimal operations, indeed we return a wrong value (the java way). So IMHO, the non-decimal case current behavior doesn't make any sense at all (considering this is SQL and not a low level language like Java/Scala) and keeping its current behavior makes no sense (we already discussed this in that PR actually).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A DB does not have to follow the SQL standard completely in every corners. The current behavior in Spark is by design and I don't think that's nonsense.

I do agree that it's a valid requirement that some users want overflow to fail, but it should be protected by a config.

My question is if we need one config for overflow, or 2 configs for decimal and non-decimal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A DB does not have to follow the SQL standard completely in every corners. The current behavior in Spark is by design and I don't think that's nonsense.

I am sorry, but I don't really agree with you on this. I see the discussion is a bit OT, but I'd like just to explain the reasons of my opinion. SQL is a declarative language and here we are coupling the result/behavior to the specific execution language we are using. Spark is cross-language, but for arithmetic operations overflow works in a very peculiar way of the language we use which is:

  • against SQL standards and no other DB works differently from SQL standards w.r.t. this, so very surprising (at least) for SQL users;
  • different from what happens in Python and in R when you overflow in those languages (an Int becomes long and so on there);

So there in no Spark user other than Scala/Java ones who might understand the behavior Spark has in those cases. Sorry for being a bit OT, anyway.

My question is if we need one config for overflow, or 2 configs for decimal and non-decimal.

Yes, this is the main point here. IMHO, I'd prefer 2 configs because when the config is turned off, the behavior is completely different: in once case it returns null, in the other we return wrong results. But I see also the value in reducing as much as possible the number of configs, which is already pretty big. So I'd prefer 2 configs, but if you and the community thinks 1 it is better, I can update the PR in order to make this config more generic.

Thanks for your feedbacks and the discussion!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I think separate flags are okay. Here's why:

  • While eventually we probably want to add flaggable non-Decimal overflow detection (see [SPARK-26218][SQL] Overflow on arithmetic operations returns incorrect result #21599 (comment)), these PRs should land separately (to limit scope of changes / code review). If we give this PR's flag a generic name, merge this PR, and then somehow fail to merge the integer overflow PR in time for 3.0 then we'd be facing a situation where we'd need to change the behavior of a released flag if we later merge the non-Decimal overflow PR.
  • If we implement separate flags for each type of overflow then that doesn't preclude us from later introducing a single flag which is used as the default value for the per-type flags.

I'm interested in whichever option allows us to make incremental progress by getting this merged (even if flagged off by default) so that we can rely on this functionality being available in 3.x instead of having to maintain it indefinitely in our own fork (with all of the associated long-term maintenance and testing burdens).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One followup question regarding flag naming: is "overflow" the most precise term for the change made here? Or does this flag also change behavior in precision-loss scenarios? Maybe I'm getting tripped up on terminology here, since insufficient precision to represent small fractional quantities is essentially an "overflow" of the digit space reserved to represent the fractional part.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comments @JoshRosen.
Yes, this deals with the overflow case. The underflow (or precision loss) is handled in a different way and the behavior depends on another config (see SPARK-22036): it either avoids precision loss, causing eventually overflow (old behavior) or truncates (as defined by the SQL standard and following closely SQL server behavior from which we derived our decimal operations implementation). So this flag is related only to the overflow case.

.internal()
.doc("When true (default), if an overflow on a decimal occurs, then NULL is returned. " +
"Spark's older versions and Hive behave in this way. If turned to false, SQL ANSI 2011 " +
"specification will be followed instead: an arithmetic exception is thrown, as most " +
"of the SQL databases do.")
.booleanConf
.createWithDefault(true)

val LITERAL_PICK_MINIMUM_PRECISION =
buildConf("spark.sql.legacy.literal.pickMinimumPrecision")
.internal()
Expand Down Expand Up @@ -2205,6 +2215,8 @@ class SQLConf extends Serializable with Logging {

def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)

def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

def continuousStreamingEpochBacklogQueueSize: Int =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,25 @@ final class Decimal extends Ordered[Decimal] with Serializable {
/**
* Create new `Decimal` with given precision and scale.
*
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
* @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null
* is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
*/
private[sql] def toPrecision(
precision: Int,
scale: Int,
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP,
nullOnOverflow: Boolean = true): Decimal = {
val copy = clone()
if (copy.changePrecision(precision, scale, roundMode)) copy else null
if (copy.changePrecision(precision, scale, roundMode)) {
copy
} else {
if (nullOnOverflow) {
null
mgaido91 marked this conversation as resolved.
Show resolved Hide resolved
} else {
throw new ArithmeticException(
s"$toDebugString cannot be represented as Decimal($precision, $scale).")
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,26 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

test("CheckOverflow") {
val d1 = Decimal("10.1")
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0), true), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1), true), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2), true), d1)
checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3), true), null)
intercept[ArithmeticException](CheckOverflow(Literal(d1), DecimalType(4, 3), false).eval())
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
CheckOverflow(Literal(d1), DecimalType(4, 3), false), null))

val d2 = Decimal(101, 3, 1)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0), true), Decimal("10"))
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1), true), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2), true), d2)
checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3), true), null)
intercept[ArithmeticException](CheckOverflow(Literal(d2), DecimalType(4, 3), false).eval())
intercept[ArithmeticException](checkEvaluationWithMutableProjection(
CheckOverflow(Literal(d2), DecimalType(4, 3), false), null))

checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
checkEvaluation(CheckOverflow(
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), true), null)
checkEvaluation(CheckOverflow(
Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2), false), null)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,28 @@ select 12345678912345678912345678912.1234567 + 9999999999999999999999999999999.1
select 123456789123456789.1234567890 * 1.123456789123456789;
select 12345678912345.123456789123 / 0.000000012345678;

-- throw an exception instead of returning NULL, according to SQL ANSI 2011
set spark.sql.decimalOperations.nullOnOverflow=false;

-- test operations between decimals and constants
select id, a*10, b/10 from decimals_test order by id;

-- test operations on constants
select 10.3 * 3.0;
select 10.3000 * 3.0;
select 10.30000 * 30.0;
select 10.300000000000000000 * 3.000000000000000000;
select 10.300000000000000000 * 3.0000000000000000000;

-- arithmetic operations causing an overflow throw exception
select (5e36 + 0.1) + 5e36;
select (-4e36 - 0.1) - 7e36;
select 12345678901234567890.0 * 12345678901234567890.0;
select 1e35 / 0.1;

-- arithmetic operations causing a precision loss throw exception
select 123456789123456789.1234567890 * 1.123456789123456789;
select 123456789123456789.1234567890 * 1.123456789123456789;
select 12345678912345.123456789123 / 0.000000012345678;

drop table decimals_test;
Loading