diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 73c9a1c7afdad..991dad90878f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -26,7 +26,14 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited) + IndexedSeq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType.Unlimited) /** * Find the tightest common type of two types that might be used in a binary expression. @@ -34,25 +41,21 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { - val valueTypes = Seq(t1, t2).filter(t => t != NullType) - if (valueTypes.distinct.size > 1) { - // Promote numeric types to the highest of the two and all numeric types to unlimited decimal - if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) { - Some(numericPrecedence.filter(t => t == t1 || t == t2).last) - } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) { - // Fixed-precision decimals can up-cast into unlimited - if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) { - Some(DecimalType.Unlimited) - } else { - None - } - } else { - None - } - } else { - Some(if (valueTypes.size == 0) NullType else valueTypes.head) - } + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) + case (t1, NullType) => Some(t1) + + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) + Some(numericPrecedence(index)) + + // Fixed-precision decimals can up-cast into unlimited + case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited) + case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited) + + case _ => None } }