From 0be5fb41a6b7ef4da9ba36f3604ac646cb6d4ae3 Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 17 Jul 2017 21:07:50 -0700 Subject: [PATCH] [SPARK-21332][SQL] Incorrect result type inferred for some decimal expressions ## What changes were proposed in this pull request? This PR changes the direction of expression transformation in the DecimalPrecision rule. Previously, the expressions were transformed down, which led to incorrect result types when decimal expressions had other decimal expressions as their operands. The root cause of this issue was in visiting outer nodes before their children. Consider the example below: ``` val inputSchema = StructType(StructField("col", DecimalType(26, 6)) :: Nil) val sc = spark.sparkContext val rdd = sc.parallelize(1 to 2).map(_ => Row(BigDecimal(12))) val df = spark.createDataFrame(rdd, inputSchema) // Works correctly since no nested decimal expression is involved // Expected result type: (26, 6) * (26, 6) = (38, 12) df.select($"col" * $"col").explain(true) df.select($"col" * $"col").printSchema() // Gives a wrong result since there is a nested decimal expression that should be visited first // Expected result type: ((26, 6) * (26, 6)) * (26, 6) = (38, 12) * (26, 6) = (38, 18) df.select($"col" * $"col" * $"col").explain(true) df.select($"col" * $"col" * $"col").printSchema() ``` The example above gives the following output: ``` // Correct result without sub-expressions == Parsed Logical Plan == 'Project [('col * 'col) AS (col * col)#4] +- LogicalRDD [col#1] == Analyzed Logical Plan == (col * col): decimal(38,12) Project [CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS (col * col)#4] +- LogicalRDD [col#1] == Optimized Logical Plan == Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4] +- LogicalRDD [col#1] == Physical Plan == *Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4] +- Scan ExistingRDD[col#1] // Schema root |-- (col * col): decimal(38,12) (nullable = true) // Incorrect result with sub-expressions == Parsed Logical Plan == 'Project [(('col * 'col) * 'col) AS ((col * col) * col)#11] +- LogicalRDD [col#1] == Analyzed Logical Plan == ((col * col) * col): decimal(38,12) Project [CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS ((col * col) * col)#11] +- LogicalRDD [col#1] == Optimized Logical Plan == Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)#11] +- LogicalRDD [col#1] == Physical Plan == *Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)#11] +- Scan ExistingRDD[col#1] // Schema root |-- ((col * col) * col): decimal(38,12) (nullable = true) ``` ## How was this patch tested? This PR was tested with available unit tests. Moreover, there are tests to cover previously failing scenarios. Author: aokolnychyi Closes #18583 from aokolnychyi/spark-21332. --- .../spark/sql/catalyst/analysis/DecimalPrecision.scala | 2 +- .../spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2ee4e53..fd2ac78b25dbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -80,7 +80,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions - case q => q.transformExpressions( + case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index ccf3c3fb0949d..11a4a79089416 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -90,8 +90,14 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { checkType(Average(d1), DecimalType(6, 5)) checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + checkType(Sum(Add(d1, d1)), DecimalType(13, 1)) } test("Comparison operations") {