diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 4c31b52d62b4c..cbeff72d2d692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -74,8 +74,9 @@ import org.apache.spark.sql.types._ */ object AnsiTypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = - InConversion :: - WidenSetOperationTypes :: + WidenSetOperationTypes :: + CombinedTypeCoercionRule( + InConversion :: PromoteStringLiterals :: DecimalPrecision :: FunctionArgumentConversion :: @@ -90,8 +91,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: - StringLiteralCoercion :: - Nil + StringLiteralCoercion :: Nil) :: Nil override def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { (t1, t2) match { @@ -260,15 +260,14 @@ object AnsiTypeCoercion extends TypeCoercionBase { */ object PromoteStringLiterals extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { - (expr.dataType, targetType) match { - case (NullType, dt) => Literal.create(null, targetType) - case (l, dt) if (l != dt) => Cast(expr, targetType) + expr.dataType match { + case NullType => Literal.create(null, targetType) + case l if l != targetType => Cast(expr, targetType) case _ => expr } } - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override def transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index d27b0a5817912..bf128cd3753e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -80,16 +79,19 @@ object DecimalPrecision extends TypeCoercionRule { PromotePrecision(Cast(e, dataType)) } - private def nullOnOverflow: Boolean = !conf.ansiEnabled + override def transform: PartialFunction[Expression, Expression] = { + decimalAndDecimal() + .orElse(integralAndDecimalLiteral) + .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision)) + } - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // fix decimal precision for expressions - case q => q.transformExpressionsUp( - decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) + private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = { + decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled) } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean) + : PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e @@ -98,43 +100,43 @@ object DecimalPrecision extends TypeCoercionRule { case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => val resultScale = max(s1, s2) - val resultType = if (conf.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } CheckOverflow( - a.withNewChildren(Seq(promotePrecision(e1, resultType), promotePrecision(e2, resultType))), + a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), resultType, nullOnOverflow) case s @ Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => val resultScale = max(s1, s2) - val resultType = if (conf.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } CheckOverflow( - s.withNewChildren(Seq(promotePrecision(e1, resultType), promotePrecision(e2, resultType))), + s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)), resultType, nullOnOverflow) case m @ Multiply( e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => - val resultType = if (conf.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) } else { DecimalType.bounded(p1 + p2 + 1, s1 + s2) } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow( - m.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))), + m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), resultType, nullOnOverflow) case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => - val resultType = if (conf.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) // Scale: max(6, s1 + p2 + 1) val intDig = p1 - s1 + s2 @@ -153,12 +155,12 @@ object DecimalPrecision extends TypeCoercionRule { } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow( - d.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))), + d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), resultType, nullOnOverflow) case r @ Remainder( e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => - val resultType = if (conf.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) @@ -166,11 +168,11 @@ object DecimalPrecision extends TypeCoercionRule { // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow( - r.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))), + r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), resultType, nullOnOverflow) case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => - val resultType = if (conf.decimalOperationsAllowPrecisionLoss) { + val resultType = if (allowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) @@ -178,14 +180,15 @@ object DecimalPrecision extends TypeCoercionRule { // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow( - p.withNewChildren(Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))), + p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)), resultType, nullOnOverflow) case expr @ IntegralDivide( e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) => val widerType = widerDecimalType(p1, s1, p2, s2) - val promotedExpr = expr.withNewChildren( - Seq(promotePrecision(e1, widerType), promotePrecision(e2, widerType))) + val promotedExpr = expr.copy( + left = promotePrecision(e1, widerType), + right = promotePrecision(e2, widerType)) if (expr.dataType.isInstanceOf[DecimalType]) { // This follows division rule val intDig = p1 - s1 + s2 @@ -301,7 +304,8 @@ object DecimalPrecision extends TypeCoercionRule { * Type coercion for BinaryOperator in which one side is a non-decimal numeric, and the other * side is a decimal. */ - private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { + private def nondecimalAndDecimal(literalPickMinimumPrecision: Boolean) + : PartialFunction[Expression, Expression] = { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => @@ -318,11 +322,11 @@ object DecimalPrecision extends TypeCoercionRule { // become DECIMAL(38, 16), safely having a much lower precision loss. case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && l.dataType.isInstanceOf[IntegralType] && - conf.literalPickMinimumPrecision => + literalPickMinimumPrecision => b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && r.dataType.isInstanceOf[IntegralType] && - conf.literalPickMinimumPrecision => + literalPickMinimumPrecision => b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 4c651a0db6d9a..6ad84651c206f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -166,6 +166,28 @@ abstract class TypeCoercionBase { } } + /** + * Type coercion rule that combines multiple type coercion rules and applies them in a single tree + * traversal. + */ + case class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule { + override def transform: PartialFunction[Expression, Expression] = { + val transforms = rules.map(_.transform) + Function.unlift { e: Expression => + val result = transforms.foldLeft(e) { + case (current, transform) => transform.applyOrElse(current, identity[Expression]) + } + if (result ne e) { + Some(result) + } else { + None + } + } + } + + override val ruleName: String = rules.map(_.ruleName).mkString("Combined[", ", ", "]") + } + /** * Widens the data types of the children of Union/Except/Intersect. * 1. When ANSI mode is off: @@ -194,9 +216,9 @@ abstract class TypeCoercionBase { * The implicit conversion is determined by the closest common data type from the precedent * lists from left and right child. See the comments of Object `AnsiTypeCoercion` for details. */ - object WidenSetOperationTypes extends TypeCoercionRule { + object WidenSetOperationTypes extends Rule[LogicalPlan] { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + override def apply(plan: LogicalPlan): LogicalPlan = { plan resolveOperatorsUpWithNewOutput { case s @ Except(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => @@ -294,8 +316,7 @@ abstract class TypeCoercionBase { * Analysis Exception will be raised at the type checking phase. */ object InConversion extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -342,8 +363,7 @@ abstract class TypeCoercionBase { */ object FunctionArgumentConversion extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -448,8 +468,7 @@ abstract class TypeCoercionBase { * converted to fractional types. */ object Division extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -458,7 +477,7 @@ abstract class TypeCoercionBase { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case d @ Divide(left, right, _) if isNumericOrNull(left) && isNumericOrNull(right) => - d.withNewChildren(Seq(Cast(left, DoubleType), Cast(right, DoubleType))) + d.copy(left = Cast(left, DoubleType), right = Cast(right, DoubleType)) } private def isNumericOrNull(ex: Expression): Boolean = { @@ -472,10 +491,10 @@ abstract class TypeCoercionBase { * This rule cast the integral inputs to long type, to avoid overflow during calculation. */ object IntegralDivision extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e case d @ IntegralDivide(left, right, _) => - d.withNewChildren(Seq(mayCastToLong(left), mayCastToLong(right))) + d.copy(left = mayCastToLong(left), right = mayCastToLong(right)) } private def mayCastToLong(expr: Expression): Expression = expr.dataType match { @@ -488,8 +507,7 @@ abstract class TypeCoercionBase { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) maybeCommonType.map { commonType => @@ -506,8 +524,7 @@ abstract class TypeCoercionBase { * Coerces the type of different branches of If statement to a common type. */ object IfCoercion extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if !haveSameType(i.inputTypesForMerging) => @@ -527,7 +544,7 @@ abstract class TypeCoercionBase { * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. */ object StackCoercion extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -547,19 +564,15 @@ abstract class TypeCoercionBase { */ object ConcatCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { - plan resolveOperators { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or empty children - case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c - case c @ Concat(children) if conf.concatBinaryAsString || - !children.map(_.dataType).forall(_ == BinaryType) => - val newChildren = c.children.map { e => - implicitCast(e, StringType).getOrElse(e) - } - c.copy(children = newChildren) + override val transform: PartialFunction[Expression, Expression] = { + // Skip nodes if unresolved or empty children + case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c + case c @ Concat(children) if conf.concatBinaryAsString || + !children.map(_.dataType).forall(_ == BinaryType) => + val newChildren = c.children.map { e => + implicitCast(e, StringType).getOrElse(e) } - } + c.copy(children = newChildren) } } @@ -568,7 +581,7 @@ abstract class TypeCoercionBase { * to a common type. */ object MapZipWithCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { // Lambda function isn't resolved when the rule is executed. case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => @@ -595,30 +608,26 @@ abstract class TypeCoercionBase { */ object EltCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { - plan resolveOperators { case p => - p transformExpressionsUp { - // Skip nodes if unresolved or not enough children - case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c - case c @ Elt(children, _) => - val index = children.head - val newIndex = implicitCast(index, IntegerType).getOrElse(index) - val newInputs = if (conf.eltOutputAsString || - !children.tail.map(_.dataType).forall(_ == BinaryType)) { - children.tail.map { e => - implicitCast(e, StringType).getOrElse(e) - } - } else { - children.tail - } - c.copy(children = newIndex +: newInputs) + override val transform: PartialFunction[Expression, Expression] = { + // Skip nodes if unresolved or not enough children + case c @ Elt(children, _) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children, _) => + val index = children.head + val newIndex = implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail } - } + c.copy(children = newIndex +: newInputs) } } - object DateTimeOperations extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object DateTimeOperations extends TypeCoercionRule { + override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e case d @ DateAdd(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) @@ -652,8 +661,7 @@ abstract class TypeCoercionBase { } } - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -705,7 +713,7 @@ abstract class TypeCoercionBase { } } - udf.withNewChildren(children) + udf.copy(children = children) } private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = { @@ -738,8 +746,7 @@ abstract class TypeCoercionBase { * Cast WindowFrame boundaries to the type they operate upon. */ object WindowFrameCoercion extends TypeCoercionRule { - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -766,7 +773,7 @@ abstract class TypeCoercionBase { * TODO(SPARK-28589): implement ANSI type type coercion and handle string literals. */ object StringLiteralCoercion extends TypeCoercionRule { - override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override val transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e case DateAdd(l, r) if r.dataType == StringType && r.foldable => @@ -805,8 +812,9 @@ abstract class TypeCoercionBase { object TypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = - InConversion :: - WidenSetOperationTypes :: + WidenSetOperationTypes :: + CombinedTypeCoercionRule( + InConversion :: PromoteStrings :: DecimalPrecision :: BooleanEquality :: @@ -822,8 +830,7 @@ object TypeCoercion extends TypeCoercionBase { ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: - StringLiteralCoercion :: - Nil + StringLiteralCoercion :: Nil) :: Nil override def canCast(from: DataType, to: DataType): Boolean = Cast.canCast(from, to) @@ -1057,8 +1064,7 @@ object TypeCoercion extends TypeCoercionBase { } } - override protected def coerceTypes( - plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override def transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -1104,11 +1110,11 @@ object TypeCoercion extends TypeCoercionBase { /** * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanEquality extends Rule[LogicalPlan] { + object BooleanEquality extends TypeCoercionRule { private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal.ONE) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal.ZERO) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override def transform: PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -1152,39 +1158,46 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { * to instances higher in the query tree. */ def apply(plan: LogicalPlan): LogicalPlan = { - val newPlan = coerceTypes(plan) - if (plan.fastEquals(newPlan)) { - plan - } else { - propagateTypes(newPlan) + val typeCoercionFn = transform + def rewrite(plan: LogicalPlan): LogicalPlan = { + val withNewChildren = plan.mapChildren(rewrite) + if (!withNewChildren.childrenResolved) { + withNewChildren + } else { + // Only propagate types if the children have changed. + val withPropagatedTypes = if (withNewChildren ne plan) { + propagateTypes(withNewChildren) + } else { + plan + } + withPropagatedTypes.transformExpressionsUp(typeCoercionFn) + } } + rewrite(plan) } - protected def coerceTypes(plan: LogicalPlan): LogicalPlan + def transform: PartialFunction[Expression, Expression] - private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - // No propagation required for leaf nodes. - case q: LogicalPlan if q.children.isEmpty => q - - // Don't propagate types from unresolved children. - case q: LogicalPlan if !q.childrenResolved => q - - case q: LogicalPlan => - val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap - q transformExpressions { + private def propagateTypes(plan: LogicalPlan): LogicalPlan = { + // Check if the inputs have changed. + val references = AttributeMap(plan.references.collect { + case a if a.resolved => a -> a + }.toSeq) + def sameButDifferent(a: Attribute): Boolean = { + references.get(a).exists(b => b.dataType != a.dataType || b.nullable != a.nullable) + } + val inputMap = AttributeMap(plan.inputSet.collect { + case a if a.resolved && sameButDifferent(a) => a -> a + }.toSeq) + if (inputMap.isEmpty) { + // Nothing changed. + plan + } else { + // Update the references if the dataType/nullability has changed. + plan transformExpressions { case a: AttributeReference => - inputMap.get(a.exprId) match { - // This can happen when an Attribute reference is born in a non-leaf node, for - // example due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a from ${a.dataType} to ${newType.dataType} in " + - s" ${q.simpleString(conf.maxToStringFields)}") - newType - } + inputMap.getOrElse(a, a) } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index b53f87c8b14e6..4fc0256bce23c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -85,7 +85,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit // We can't directly use `/` as it throws an exception under ansi mode. override lazy val evaluateExpression = child.dataType match { case _: DecimalType => - DecimalPrecision.decimalAndDecimal( + DecimalPrecision.decimalAndDecimal()( Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType) case _: YearMonthIntervalType => DivideYMInterval(sum, count) case _: DayTimeIntervalType => DivideDTInterval(sum, count) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 795af05078a64..e4e546aa158b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -343,9 +343,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] }.asInstanceOf[PlanType] } - private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - val exprId = attrMap.getOrElse(attr, attr).exprId - attr.withExprId(exprId) + private def updateAttr(a: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + attrMap.get(a) match { + case Some(b) => + AttributeReference(a.name, b.dataType, b.nullable, a.metadata)(b.exprId, a.qualifier) + case None => a + } } /** diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index 7955f5a7e878e..8a4ee142011ce 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -4673,7 +4673,7 @@ struct<(CAST(999999999999999999999 AS DECIMAL(38,0)) div 1000000000000000000000) -- !query select mod(cast(999999999999999999999 as decimal(38, 0)),1000000000000000000000) -- !query schema -struct +struct<(CAST(999999999999999999999 AS DECIMAL(38,0)) % 1000000000000000000000):decimal(22,0)> -- !query output 999999999999999999999 @@ -4689,7 +4689,7 @@ struct<(CAST(-9999999999999999999999 AS DECIMAL(38,0)) div 100000000000000000000 -- !query select mod(cast(-9999999999999999999999 as decimal(38, 0)),1000000000000000000000) -- !query schema -struct +struct<(CAST(-9999999999999999999999 AS DECIMAL(38,0)) % 1000000000000000000000):decimal(22,0)> -- !query output -999999999999999999999 @@ -4697,7 +4697,7 @@ struct +struct<(((CAST(-9999999999999999999999 AS DECIMAL(38,0)) div 1000000000000000000000) * 1000000000000000000000) + (CAST(-9999999999999999999999 AS DECIMAL(38,0)) % 1000000000000000000000)):decimal(38,0)> -- !query output -9999999999999999999999 @@ -4705,7 +4705,7 @@ struct<(((CAST(-9999999999999999999999 AS DECIMAL(38,0)) div 1000000000000000000 -- !query select mod (70.0,70) -- !query schema -struct +struct<(70.0 % 70):decimal(3,1)> -- !query output 0.0 diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt index 1f4f137f42c4a..4c884e185904f 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a.sf100/explain.txt @@ -523,55 +523,55 @@ Output [8]: [channel#36, id#37, sum#126, isEmpty#127, sum#128, isEmpty#129, sum# (89) HashAggregate [codegen id : 48] Input [8]: [channel#36, id#37, sum#126, isEmpty#127, sum#128, isEmpty#129, sum#130, isEmpty#131] Keys [2]: [channel#36, id#37] -Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#132)] -Aggregate Attributes [3]: [sum(sales#18)#133, sum(returns#38)#134, sum(profit#132)#135] -Results [4]: [channel#36, sum(sales#18)#133 AS sales#136, sum(returns#38)#134 AS returns#137, sum(profit#132)#135 AS profit#138] +Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#39)] +Aggregate Attributes [3]: [sum(sales#18)#132, sum(returns#38)#133, sum(profit#39)#134] +Results [4]: [channel#36, sum(sales#18)#132 AS sales#135, sum(returns#38)#133 AS returns#136, sum(profit#39)#134 AS profit#137] (90) HashAggregate [codegen id : 48] -Input [4]: [channel#36, sales#136, returns#137, profit#138] +Input [4]: [channel#36, sales#135, returns#136, profit#137] Keys [1]: [channel#36] -Functions [3]: [partial_sum(sales#136), partial_sum(returns#137), partial_sum(profit#138)] -Aggregate Attributes [6]: [sum#139, isEmpty#140, sum#141, isEmpty#142, sum#143, isEmpty#144] -Results [7]: [channel#36, sum#145, isEmpty#146, sum#147, isEmpty#148, sum#149, isEmpty#150] +Functions [3]: [partial_sum(sales#135), partial_sum(returns#136), partial_sum(profit#137)] +Aggregate Attributes [6]: [sum#138, isEmpty#139, sum#140, isEmpty#141, sum#142, isEmpty#143] +Results [7]: [channel#36, sum#144, isEmpty#145, sum#146, isEmpty#147, sum#148, isEmpty#149] (91) Exchange -Input [7]: [channel#36, sum#145, isEmpty#146, sum#147, isEmpty#148, sum#149, isEmpty#150] -Arguments: hashpartitioning(channel#36, 5), ENSURE_REQUIREMENTS, [id=#151] +Input [7]: [channel#36, sum#144, isEmpty#145, sum#146, isEmpty#147, sum#148, isEmpty#149] +Arguments: hashpartitioning(channel#36, 5), ENSURE_REQUIREMENTS, [id=#150] (92) HashAggregate [codegen id : 49] -Input [7]: [channel#36, sum#145, isEmpty#146, sum#147, isEmpty#148, sum#149, isEmpty#150] +Input [7]: [channel#36, sum#144, isEmpty#145, sum#146, isEmpty#147, sum#148, isEmpty#149] Keys [1]: [channel#36] -Functions [3]: [sum(sales#136), sum(returns#137), sum(profit#138)] -Aggregate Attributes [3]: [sum(sales#136)#152, sum(returns#137)#153, sum(profit#138)#154] -Results [5]: [channel#36, null AS id#155, sum(sales#136)#152 AS sales#156, sum(returns#137)#153 AS returns#157, sum(profit#138)#154 AS profit#158] +Functions [3]: [sum(sales#135), sum(returns#136), sum(profit#137)] +Aggregate Attributes [3]: [sum(sales#135)#151, sum(returns#136)#152, sum(profit#137)#153] +Results [5]: [channel#36, null AS id#154, sum(sales#135)#151 AS sales#155, sum(returns#136)#152 AS returns#156, sum(profit#137)#153 AS profit#157] (93) ReusedExchange [Reuses operator id: unknown] -Output [8]: [channel#36, id#37, sum#159, isEmpty#160, sum#161, isEmpty#162, sum#163, isEmpty#164] +Output [8]: [channel#36, id#37, sum#158, isEmpty#159, sum#160, isEmpty#161, sum#162, isEmpty#163] (94) HashAggregate [codegen id : 73] -Input [8]: [channel#36, id#37, sum#159, isEmpty#160, sum#161, isEmpty#162, sum#163, isEmpty#164] +Input [8]: [channel#36, id#37, sum#158, isEmpty#159, sum#160, isEmpty#161, sum#162, isEmpty#163] Keys [2]: [channel#36, id#37] -Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#165)] -Aggregate Attributes [3]: [sum(sales#18)#166, sum(returns#38)#167, sum(profit#165)#168] -Results [3]: [sum(sales#18)#166 AS sales#136, sum(returns#38)#167 AS returns#137, sum(profit#165)#168 AS profit#138] +Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#39)] +Aggregate Attributes [3]: [sum(sales#18)#164, sum(returns#38)#165, sum(profit#39)#166] +Results [3]: [sum(sales#18)#164 AS sales#135, sum(returns#38)#165 AS returns#136, sum(profit#39)#166 AS profit#137] (95) HashAggregate [codegen id : 73] -Input [3]: [sales#136, returns#137, profit#138] +Input [3]: [sales#135, returns#136, profit#137] Keys: [] -Functions [3]: [partial_sum(sales#136), partial_sum(returns#137), partial_sum(profit#138)] -Aggregate Attributes [6]: [sum#169, isEmpty#170, sum#171, isEmpty#172, sum#173, isEmpty#174] -Results [6]: [sum#175, isEmpty#176, sum#177, isEmpty#178, sum#179, isEmpty#180] +Functions [3]: [partial_sum(sales#135), partial_sum(returns#136), partial_sum(profit#137)] +Aggregate Attributes [6]: [sum#167, isEmpty#168, sum#169, isEmpty#170, sum#171, isEmpty#172] +Results [6]: [sum#173, isEmpty#174, sum#175, isEmpty#176, sum#177, isEmpty#178] (96) Exchange -Input [6]: [sum#175, isEmpty#176, sum#177, isEmpty#178, sum#179, isEmpty#180] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#181] +Input [6]: [sum#173, isEmpty#174, sum#175, isEmpty#176, sum#177, isEmpty#178] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#179] (97) HashAggregate [codegen id : 74] -Input [6]: [sum#175, isEmpty#176, sum#177, isEmpty#178, sum#179, isEmpty#180] +Input [6]: [sum#173, isEmpty#174, sum#175, isEmpty#176, sum#177, isEmpty#178] Keys: [] -Functions [3]: [sum(sales#136), sum(returns#137), sum(profit#138)] -Aggregate Attributes [3]: [sum(sales#136)#182, sum(returns#137)#183, sum(profit#138)#184] -Results [5]: [null AS channel#185, null AS id#186, sum(sales#136)#182 AS sales#187, sum(returns#137)#183 AS returns#188, sum(profit#138)#184 AS profit#189] +Functions [3]: [sum(sales#135), sum(returns#136), sum(profit#137)] +Aggregate Attributes [3]: [sum(sales#135)#180, sum(returns#136)#181, sum(profit#137)#182] +Results [5]: [null AS channel#183, null AS id#184, sum(sales#135)#180 AS sales#185, sum(returns#136)#181 AS returns#186, sum(profit#137)#182 AS profit#187] (98) Union @@ -584,7 +584,7 @@ Results [5]: [channel#36, id#37, sales#123, returns#124, profit#125] (100) Exchange Input [5]: [channel#36, id#37, sales#123, returns#124, profit#125] -Arguments: hashpartitioning(channel#36, id#37, sales#123, returns#124, profit#125, 5), ENSURE_REQUIREMENTS, [id=#190] +Arguments: hashpartitioning(channel#36, id#37, sales#123, returns#124, profit#125, 5), ENSURE_REQUIREMENTS, [id=#188] (101) HashAggregate [codegen id : 76] Input [5]: [channel#36, id#37, sales#123, returns#124, profit#125] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt index 42c9e941756c3..c74b44df70a65 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q77a/explain.txt @@ -523,55 +523,55 @@ Output [8]: [channel#36, id#37, sum#126, isEmpty#127, sum#128, isEmpty#129, sum# (89) HashAggregate [codegen id : 48] Input [8]: [channel#36, id#37, sum#126, isEmpty#127, sum#128, isEmpty#129, sum#130, isEmpty#131] Keys [2]: [channel#36, id#37] -Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#132)] -Aggregate Attributes [3]: [sum(sales#18)#133, sum(returns#38)#134, sum(profit#132)#135] -Results [4]: [channel#36, sum(sales#18)#133 AS sales#136, sum(returns#38)#134 AS returns#137, sum(profit#132)#135 AS profit#138] +Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#39)] +Aggregate Attributes [3]: [sum(sales#18)#132, sum(returns#38)#133, sum(profit#39)#134] +Results [4]: [channel#36, sum(sales#18)#132 AS sales#135, sum(returns#38)#133 AS returns#136, sum(profit#39)#134 AS profit#137] (90) HashAggregate [codegen id : 48] -Input [4]: [channel#36, sales#136, returns#137, profit#138] +Input [4]: [channel#36, sales#135, returns#136, profit#137] Keys [1]: [channel#36] -Functions [3]: [partial_sum(sales#136), partial_sum(returns#137), partial_sum(profit#138)] -Aggregate Attributes [6]: [sum#139, isEmpty#140, sum#141, isEmpty#142, sum#143, isEmpty#144] -Results [7]: [channel#36, sum#145, isEmpty#146, sum#147, isEmpty#148, sum#149, isEmpty#150] +Functions [3]: [partial_sum(sales#135), partial_sum(returns#136), partial_sum(profit#137)] +Aggregate Attributes [6]: [sum#138, isEmpty#139, sum#140, isEmpty#141, sum#142, isEmpty#143] +Results [7]: [channel#36, sum#144, isEmpty#145, sum#146, isEmpty#147, sum#148, isEmpty#149] (91) Exchange -Input [7]: [channel#36, sum#145, isEmpty#146, sum#147, isEmpty#148, sum#149, isEmpty#150] -Arguments: hashpartitioning(channel#36, 5), ENSURE_REQUIREMENTS, [id=#151] +Input [7]: [channel#36, sum#144, isEmpty#145, sum#146, isEmpty#147, sum#148, isEmpty#149] +Arguments: hashpartitioning(channel#36, 5), ENSURE_REQUIREMENTS, [id=#150] (92) HashAggregate [codegen id : 49] -Input [7]: [channel#36, sum#145, isEmpty#146, sum#147, isEmpty#148, sum#149, isEmpty#150] +Input [7]: [channel#36, sum#144, isEmpty#145, sum#146, isEmpty#147, sum#148, isEmpty#149] Keys [1]: [channel#36] -Functions [3]: [sum(sales#136), sum(returns#137), sum(profit#138)] -Aggregate Attributes [3]: [sum(sales#136)#152, sum(returns#137)#153, sum(profit#138)#154] -Results [5]: [channel#36, null AS id#155, sum(sales#136)#152 AS sales#156, sum(returns#137)#153 AS returns#157, sum(profit#138)#154 AS profit#158] +Functions [3]: [sum(sales#135), sum(returns#136), sum(profit#137)] +Aggregate Attributes [3]: [sum(sales#135)#151, sum(returns#136)#152, sum(profit#137)#153] +Results [5]: [channel#36, null AS id#154, sum(sales#135)#151 AS sales#155, sum(returns#136)#152 AS returns#156, sum(profit#137)#153 AS profit#157] (93) ReusedExchange [Reuses operator id: unknown] -Output [8]: [channel#36, id#37, sum#159, isEmpty#160, sum#161, isEmpty#162, sum#163, isEmpty#164] +Output [8]: [channel#36, id#37, sum#158, isEmpty#159, sum#160, isEmpty#161, sum#162, isEmpty#163] (94) HashAggregate [codegen id : 73] -Input [8]: [channel#36, id#37, sum#159, isEmpty#160, sum#161, isEmpty#162, sum#163, isEmpty#164] +Input [8]: [channel#36, id#37, sum#158, isEmpty#159, sum#160, isEmpty#161, sum#162, isEmpty#163] Keys [2]: [channel#36, id#37] -Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#165)] -Aggregate Attributes [3]: [sum(sales#18)#166, sum(returns#38)#167, sum(profit#165)#168] -Results [3]: [sum(sales#18)#166 AS sales#136, sum(returns#38)#167 AS returns#137, sum(profit#165)#168 AS profit#138] +Functions [3]: [sum(sales#18), sum(returns#38), sum(profit#39)] +Aggregate Attributes [3]: [sum(sales#18)#164, sum(returns#38)#165, sum(profit#39)#166] +Results [3]: [sum(sales#18)#164 AS sales#135, sum(returns#38)#165 AS returns#136, sum(profit#39)#166 AS profit#137] (95) HashAggregate [codegen id : 73] -Input [3]: [sales#136, returns#137, profit#138] +Input [3]: [sales#135, returns#136, profit#137] Keys: [] -Functions [3]: [partial_sum(sales#136), partial_sum(returns#137), partial_sum(profit#138)] -Aggregate Attributes [6]: [sum#169, isEmpty#170, sum#171, isEmpty#172, sum#173, isEmpty#174] -Results [6]: [sum#175, isEmpty#176, sum#177, isEmpty#178, sum#179, isEmpty#180] +Functions [3]: [partial_sum(sales#135), partial_sum(returns#136), partial_sum(profit#137)] +Aggregate Attributes [6]: [sum#167, isEmpty#168, sum#169, isEmpty#170, sum#171, isEmpty#172] +Results [6]: [sum#173, isEmpty#174, sum#175, isEmpty#176, sum#177, isEmpty#178] (96) Exchange -Input [6]: [sum#175, isEmpty#176, sum#177, isEmpty#178, sum#179, isEmpty#180] -Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#181] +Input [6]: [sum#173, isEmpty#174, sum#175, isEmpty#176, sum#177, isEmpty#178] +Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#179] (97) HashAggregate [codegen id : 74] -Input [6]: [sum#175, isEmpty#176, sum#177, isEmpty#178, sum#179, isEmpty#180] +Input [6]: [sum#173, isEmpty#174, sum#175, isEmpty#176, sum#177, isEmpty#178] Keys: [] -Functions [3]: [sum(sales#136), sum(returns#137), sum(profit#138)] -Aggregate Attributes [3]: [sum(sales#136)#182, sum(returns#137)#183, sum(profit#138)#184] -Results [5]: [null AS channel#185, null AS id#186, sum(sales#136)#182 AS sales#187, sum(returns#137)#183 AS returns#188, sum(profit#138)#184 AS profit#189] +Functions [3]: [sum(sales#135), sum(returns#136), sum(profit#137)] +Aggregate Attributes [3]: [sum(sales#135)#180, sum(returns#136)#181, sum(profit#137)#182] +Results [5]: [null AS channel#183, null AS id#184, sum(sales#135)#180 AS sales#185, sum(returns#136)#181 AS returns#186, sum(profit#137)#182 AS profit#187] (98) Union @@ -584,7 +584,7 @@ Results [5]: [channel#36, id#37, sales#123, returns#124, profit#125] (100) Exchange Input [5]: [channel#36, id#37, sales#123, returns#124, profit#125] -Arguments: hashpartitioning(channel#36, id#37, sales#123, returns#124, profit#125, 5), ENSURE_REQUIREMENTS, [id=#190] +Arguments: hashpartitioning(channel#36, id#37, sales#123, returns#124, profit#125, 5), ENSURE_REQUIREMENTS, [id=#188] (101) HashAggregate [codegen id : 76] Input [5]: [channel#36, id#37, sales#123, returns#124, profit#125]