Skip to content

Commit

Permalink
[SPARK-35103][SQL] Make TypeCoercion rules more efficient
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This PR fixes a couple of things in TypeCoercion rules:
- Only run the propagate types step if the children of a node have output attributes with changed dataTypes and/or nullability. This is implemented as a custom tree transformation. The TypeCoercion rules now only implement a partial function.
- Combine multiple type coercion rules into a single rule. Multiple rules are applied in a single tree traversal.
- Reduce calls to conf.get in DecimalPrecision. This now happens once per tree traversal, instead of once per matched expression.
- Reduce the use of withNewChildren.

This brings down the number of CPU cycles spent in analysis by ~28% (benchmark: 10 iterations of all TPC-DS queries on SF10).

## How was this patch tested?
Existing tests.

Closes #32208 from sigmod/coercion.

Authored-by: Yingyi Bu <[email protected]>
Signed-off-by: herman <[email protected]>
  • Loading branch information
sigmod authored and hvanhovell committed Apr 19, 2021
1 parent 00f06dd commit 9a6d773
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ import org.apache.spark.sql.types._
*/
object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
InConversion ::
WidenSetOperationTypes ::
WidenSetOperationTypes ::
CombinedTypeCoercionRule(
InConversion ::
PromoteStringLiterals ::
DecimalPrecision ::
FunctionArgumentConversion ::
Expand All @@ -90,8 +91,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
ImplicitTypeCasts ::
DateTimeOperations ::
WindowFrameCoercion ::
StringLiteralCoercion ::
Nil
StringLiteralCoercion :: Nil) :: Nil

override def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
(t1, t2) match {
Expand Down Expand Up @@ -260,15 +260,14 @@ object AnsiTypeCoercion extends TypeCoercionBase {
*/
object PromoteStringLiterals extends TypeCoercionRule {
private def castExpr(expr: Expression, targetType: DataType): Expression = {
(expr.dataType, targetType) match {
case (NullType, dt) => Literal.create(null, targetType)
case (l, dt) if (l != dt) => Cast(expr, targetType)
expr.dataType match {
case NullType => Literal.create(null, targetType)
case l if l != targetType => Cast(expr, targetType)
case _ => expr
}
}

override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
override def transform: PartialFunction[Expression, Expression] = {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -80,16 +79,19 @@ object DecimalPrecision extends TypeCoercionRule {
PromotePrecision(Cast(e, dataType))
}

private def nullOnOverflow: Boolean = !conf.ansiEnabled
override def transform: PartialFunction[Expression, Expression] = {
decimalAndDecimal()
.orElse(integralAndDecimalLiteral)
.orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision))
}

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// fix decimal precision for expressions
case q => q.transformExpressionsUp(
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled)
}

/** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean)
: PartialFunction[Expression, Expression] = {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e

Expand All @@ -98,43 +100,43 @@ object DecimalPrecision extends TypeCoercionRule {

case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultScale = max(s1, s2)
val resultType = if (conf.decimalOperationsAllowPrecisionLoss) {
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
} else {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(
a.withNewChildren(Seq(promotePrecision(e1, resultType), promotePrecision(e2, resultType))),
a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
resultType, nullOnOverflow)

case s @ Subtract(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2), _) =>
val resultScale = max(s1, s2)
val resultType = if (conf.decimalOperationsAllowPrecisionLoss) {
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
resultScale)
} else {
DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
}
CheckOverflow(
s.withNewChildren(Seq(promotePrecision(e1, resultType), promotePrecision(e2, resultType))),
s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
resultType, nullOnOverflow)

case m @ Multiply(
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (conf.decimalOperationsAllowPrecisionLoss) {
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
} else {
DecimalType.bounded(p1 + p2 + 1, s1 + s2)
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(
m.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
resultType, nullOnOverflow)

case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (conf.decimalOperationsAllowPrecisionLoss) {
val resultType = if (allowPrecisionLoss) {
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
val intDig = p1 - s1 + s2
Expand All @@ -153,39 +155,40 @@ object DecimalPrecision extends TypeCoercionRule {
}
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(
d.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
resultType, nullOnOverflow)

case r @ Remainder(
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (conf.decimalOperationsAllowPrecisionLoss) {
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
}
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(
r.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
resultType, nullOnOverflow)

case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val resultType = if (conf.decimalOperationsAllowPrecisionLoss) {
val resultType = if (allowPrecisionLoss) {
DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
} else {
DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
}
// resultType may have lower precision, so we cast them into wider type first.
val widerType = widerDecimalType(p1, s1, p2, s2)
CheckOverflow(
p.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))),
p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
resultType, nullOnOverflow)

case expr @ IntegralDivide(
e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
val widerType = widerDecimalType(p1, s1, p2, s2)
val promotedExpr = expr.withNewChildren(
Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType)))
val promotedExpr = expr.copy(
left = promotePrecision(e1, widerType),
right = promotePrecision(e2, widerType))
if (expr.dataType.isInstanceOf[DecimalType]) {
// This follows division rule
val intDig = p1 - s1 + s2
Expand Down Expand Up @@ -301,7 +304,8 @@ object DecimalPrecision extends TypeCoercionRule {
* Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other
* side is a decimal.
*/
private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = {
private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean)
: PartialFunction[Expression, Expression] = {
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
Expand All @@ -318,11 +322,11 @@ object DecimalPrecision extends TypeCoercionRule {
// become DECIMAL(38, 16), safely having a much lower precision loss.
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
l.dataType.isInstanceOf[IntegralType] &&
conf.literalPickMinimumPrecision =>
literalPickMinimumPrecision =>
b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r))
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
r.dataType.isInstanceOf[IntegralType] &&
conf.literalPickMinimumPrecision =>
literalPickMinimumPrecision =>
b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r))))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
Expand Down
Loading

0 comments on commit 9a6d773

Please sign in to comment.