Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow #25010

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -46,19 +47,38 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
*/
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {

private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow

override def dataType: DataType = DecimalType(precision, scale)
override def nullable: Boolean = true
override def nullable: Boolean = child.nullable || nullOnOverflow
override def toString: String = s"MakeDecimal($child,$precision,$scale)"

protected override def nullSafeEval(input: Any): Any =
Decimal(input.asInstanceOf[Long], precision, scale)
protected override def nullSafeEval(input: Any): Any = {
val longInput = input.asInstanceOf[Long]
val result = new Decimal()
if (nullOnOverflow) {
result.setOrNull(longInput, precision, scale)
} else {
result.set(longInput, precision, scale)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
val setMethod = if (nullOnOverflow) {
"setOrNull"
} else {
"set"
}
val setNull = if (nullable) {
mgaido91 marked this conversation as resolved.
Show resolved Hide resolved
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
s"""
${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
${ev.isNull} = ${ev.value} == null;
"""
|${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
|$setNull
|""".stripMargin
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
if (setOrNull(unscaled, precision, scale) == null) {
throw new IllegalArgumentException("Unscaled value too large for precision")
throw new ArithmeticException("Unscaled value too large for precision")
}
this
}
Expand Down Expand Up @@ -111,9 +111,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
require(
decimalVal.precision <= precision,
s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
if (decimalVal.precision > precision) {
throw new ArithmeticException(
s"Decimal precision ${decimalVal.precision} exceeds max precision $precision")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, All.
This seems to break JDBC Integration Test suite. I'll make a followup PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made #25165 .

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @dongjoon-hyun , sorry for the trouble. I hadn't run integration tests indeed.

}
this.longVal = 0L
this._precision = precision
this._scale = scale
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}

class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand All @@ -31,8 +32,23 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("MakeDecimal") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
checkEvaluation(overflowExpr, null)
checkEvaluationWithMutableProjection(overflowExpr, null)
evaluateWithoutCodegen(overflowExpr, null)
checkEvaluationWithUnsafeProjection(overflowExpr, null)
}
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr, null))
intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr, null))
intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr, null))
}
}

test("PromotePrecision") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2)
checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
intercept[IllegalArgumentException](Decimal(170L, 2, 1))
intercept[IllegalArgumentException](Decimal(170L, 2, 0))
intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
intercept[ArithmeticException](Decimal(170L, 2, 1))
intercept[ArithmeticException](Decimal(170L, 2, 0))
intercept[ArithmeticException](Decimal(BigDecimal("10.030"), 2, 1))
intercept[ArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1))
intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0))
}

test("creating decimals with negative scale") {
Expand Down