forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from marmbrus/unionTypes
Basic type promotion when the types of union-ed expressions are different.
- Loading branch information
Showing
5 changed files
with
268 additions
and
218 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
253 changes: 253 additions & 0 deletions
253
src/main/scala/catalyst/analysis/HiveTypeCoercion.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
package catalyst | ||
package analysis | ||
|
||
import expressions._ | ||
import plans.logical._ | ||
import rules._ | ||
import types._ | ||
import catalyst.execution.{HiveUdf, HiveGenericUdf} | ||
|
||
/** | ||
* A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that | ||
* participate in operations into compatible ones. Most of these rules are based on Hive semantics, | ||
* but they do not introduce any dependencies on the hive codebase. For this reason they remain in | ||
* Catalyst until we have a more standard set of coercions. | ||
*/ | ||
trait HiveTypeCoercion { | ||
|
||
val typeCoercionRules = | ||
List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts, | ||
StringToIntegralCasts, FunctionArgumentConversion) | ||
|
||
/** | ||
* Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] dataTypes | ||
* that are made by other rules to instances higher in the query tree. | ||
*/ | ||
object PropagateTypes extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
// No propagation required for leaf nodes. | ||
case q: LogicalPlan if q.children.isEmpty => q | ||
|
||
case q: LogicalPlan => q transformExpressions { | ||
case a: AttributeReference => | ||
q.inputSet.find(_.exprId == a.exprId) match { | ||
// This can happen when a Attribute reference is born in a non-leaf node, for example | ||
// due to a call to an external script like in the Transform operator. | ||
// TODO: Perhaps those should actually be aliases? | ||
case None => a | ||
// Leave the same if the dataTypes match. | ||
case Some(newType) if a.dataType == newType.dataType => a | ||
case Some(newType) => | ||
logger.debug(s"Promoting $a to ${newType} in ${q.simpleString}}") | ||
newType | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to | ||
* the appropriate numeric equivalent. | ||
*/ | ||
object ConvertNaNs extends Rule[LogicalPlan] { | ||
val stringNaN = Literal("NaN", StringType) | ||
|
||
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
case q: LogicalPlan => q transformExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
|
||
/* Double Conversions */ | ||
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => | ||
b.makeCopy(Array(b.right, Literal(Double.NaN))) | ||
case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => | ||
b.makeCopy(Array(Literal(Double.NaN), b.left)) | ||
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => | ||
b.makeCopy(Array(Literal(Double.NaN), b.left)) | ||
|
||
/* Float Conversions */ | ||
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => | ||
b.makeCopy(Array(b.right, Literal(Float.NaN))) | ||
case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => | ||
b.makeCopy(Array(Literal(Float.NaN), b.left)) | ||
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => | ||
b.makeCopy(Array(Literal(Float.NaN), b.left)) | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Widens numeric types and converts strings to numbers when appropriate. | ||
* | ||
* Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White | ||
* | ||
* The implicit conversion rules can be summarized as follows: | ||
* - Any integral numeric type can be implicitly converted to a wider type. | ||
* - All the integral numeric types, FLOAT, and (perhaps surprisingly) STRING can be implicitly | ||
* converted to DOUBLE. | ||
* - TINYINT, SMALLINT, and INT can all be converted to FLOAT. | ||
* - BOOLEAN types cannot be converted to any other type. | ||
* | ||
* Additionally, all types when UNION-ed with strings will be promoted to strings. | ||
* Other string conversions are handled by PromoteStrings | ||
*/ | ||
object WidenTypes extends Rule[LogicalPlan] { | ||
val integralPrecedence = Seq(NullType, ByteType, ShortType, IntegerType, LongType) | ||
val toDouble = integralPrecedence ++ Seq(NullType, FloatType, DoubleType) | ||
val toFloat = Seq(NullType, ByteType, ShortType, IntegerType) :+ FloatType | ||
val allPromotions: Seq[Seq[DataType]] = integralPrecedence :: toDouble :: toFloat :: Nil | ||
|
||
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { | ||
// Try and find a promotion rule that contains both types in question. | ||
val applicableConversion = | ||
allPromotions.find(p => p.contains(t1) && p.contains(t2)) | ||
|
||
// If found return the widest common type, otherwise None | ||
applicableConversion.map(_.filter(t => t == t1 || t == t2).last) | ||
} | ||
|
||
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
case u @ Union(left, right) if u.childrenResolved && !u.resolved => | ||
val castedInput = left.output.zip(right.output).map { | ||
// When a string is found on one side, make the other side a string too. | ||
case (l,r) if l.dataType == StringType && r.dataType != StringType => | ||
(l, Alias(Cast(r, StringType), r.name)()) | ||
case (l,r) if l.dataType != StringType && r.dataType == StringType => | ||
(Alias(Cast(l, StringType), l.name)(), r) | ||
|
||
case (l,r) if l.dataType != r.dataType => | ||
logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") | ||
findTightestCommonType(l.dataType, r.dataType).map { widestType => | ||
val newLeft = | ||
if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() | ||
val newRight = | ||
if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() | ||
|
||
(newLeft, newRight) | ||
}.getOrElse((l,r)) // If there is no applicable conversion, leave expression unchanged. | ||
case other => other | ||
} | ||
|
||
val (castedLeft, castedRight) = castedInput.unzip | ||
|
||
val newLeft = | ||
if(castedLeft.map(_.dataType) != left.output.map(_.dataType)) { | ||
logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") | ||
Project(castedLeft, left) | ||
} else { | ||
left | ||
} | ||
|
||
val newRight = | ||
if(castedRight.map(_.dataType) != right.output.map(_.dataType)) { | ||
logger.debug(s"Widening numeric types in union $castedRight ${right.output}") | ||
Project(castedRight, right) | ||
} else { | ||
right | ||
} | ||
|
||
Union(newLeft, newRight) | ||
|
||
// Also widen types for BinaryExpressions. | ||
case q: LogicalPlan => q transformExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
|
||
case b: BinaryExpression if b.left.dataType != b.right.dataType => | ||
findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => | ||
val newLeft = | ||
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) | ||
val newRight = | ||
if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) | ||
b.makeCopy(Array(newLeft, newRight)) | ||
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Promotes strings that appear in arithmetic expressions. | ||
*/ | ||
object PromoteStrings extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
|
||
case a: BinaryArithmetic if a.left.dataType == StringType => | ||
a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) | ||
case a: BinaryArithmetic if a.right.dataType == StringType => | ||
a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) | ||
|
||
case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => | ||
p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) | ||
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => | ||
p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) | ||
|
||
case Sum(e) if e.dataType == StringType => | ||
Sum(Cast(e, DoubleType)) | ||
case Average(e) if e.dataType == StringType => | ||
Sum(Cast(e, DoubleType)) | ||
} | ||
} | ||
|
||
/** | ||
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. | ||
*/ | ||
object BooleanComparisons extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
// No need to change Equals operators as that actually makes sense for boolean types. | ||
case e: Equals => e | ||
// Otherwise turn them to Byte types so that there exists and ordering. | ||
case p: BinaryComparison if p.left.dataType == BooleanType && p.right.dataType == BooleanType => | ||
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) | ||
} | ||
} | ||
|
||
/** | ||
* Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since | ||
* the JVM does not consider Booleans to be numeric types. | ||
*/ | ||
object BooleanCasts extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
|
||
case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) | ||
case Cast(e, dataType) if e.dataType == BooleanType => | ||
Cast(If(e, Literal(1), Literal(0)), dataType) | ||
} | ||
} | ||
|
||
/** | ||
* When encountering a cast from a string representing a valid fractional number to an integral | ||
* type the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the | ||
* truncated version of this number. | ||
*/ | ||
object StringToIntegralCasts extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
|
||
case Cast(e @ StringType(), t: IntegralType) => | ||
Cast(Cast(e, DecimalType), t) | ||
} | ||
} | ||
|
||
/** | ||
* This ensure that the types for various functions are as expected. | ||
*/ | ||
object FunctionArgumentConversion extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
// Skip nodes who's children have not been resolved yet. | ||
case e if !e.childrenResolved => e | ||
|
||
// Promote SUM to largest types to prevent overflows. | ||
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. | ||
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) | ||
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) | ||
|
||
} | ||
} | ||
} |
Oops, something went wrong.