diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d31e004b9c348..634750dca2158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -435,37 +435,57 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = this match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // TODO(cg): Add support for more data types. + (child.dataType, dataType) match { - case Cast(child @ BinaryType(), StringType) => - castOrNull (ctx, ev, c => - s"new ${ctx.stringType}().set($c)") + case (BinaryType, StringType) => + defineCodeGen (ctx, ev, c => + s"new ${ctx.stringType}().set($c)") - case Cast(child @ DateType(), StringType) => - castOrNull(ctx, ev, c => - s"""new ${ctx.stringType}().set( + case (DateType, StringType) => + defineCodeGen(ctx, ev, c => + s"""new ${ctx.stringType}().set( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c?1:0)") + case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)") - case Cast(child @ DecimalType(), IntegerType) => - castOrNull(ctx, ev, c => s"($c).toInt()") + case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => + defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + case (_: DecimalType, ByteType) => + defineCodeGen(ctx, ev, c => s"($c).toByte()") - case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - castOrNull(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)") + case (_: DecimalType, ShortType) => + defineCodeGen(ctx, ev, c => s"($c).toShort()") - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. - case Cast(e, StringType) if e.dataType != TimestampType => - castOrNull(ctx, ev, c => - s"new ${ctx.stringType}().set(String.valueOf($c))") + case (_: DecimalType, IntegerType) => + defineCodeGen(ctx, ev, c => s"($c).toInt()") - case other => - super.genCode(ctx, ev) + case (_: DecimalType, LongType) => + defineCodeGen(ctx, ev, c => s"($c).toLong()") + + case (_: DecimalType, FloatType) => + defineCodeGen(ctx, ev, c => s"($c).toFloat()") + + case (_: DecimalType, DoubleType) => + defineCodeGen(ctx, ev, c => s"($c).toDouble()") + + case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case (TimestampType, StringType) => + super.genCode(ctx, ev) + + case (_, StringType) => + defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + + case other => + super.genCode(ctx, ev) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1f1a2fc9694af..db085c8c277ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -69,7 +69,9 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Returns Java source code for this expression. + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. * * @param ctx a [[CodeGenContext]] * @param ev an [[GeneratedExpressionCode]] with unique terms. @@ -82,10 +84,10 @@ abstract class Expression extends TreeNode[Expression] { /* expression: ${this} */ Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); boolean ${ev.nullTerm} = ${ev.objectTerm} == null; - ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultValue(e.dataType)}; - if (!${ev.nullTerm}) ${ev.primitiveTerm} = - (${ctx.boxedType(e.dataType)})${ev.objectTerm}; + ${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm}; + } """ } @@ -155,17 +157,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" - /** * Short hand for generating binary evaluation code, which depends on two sub-evaluations of * the same type. If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f a function from two primitive term names to a tree that evaluates them. + * @param f accepts two variable names and returns Java code to compute the output. */ - def evaluate(ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -197,9 +199,22 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - def castOrNull(ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: String => String): String = { + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + * + * As an example, the following does a boolean inversion (i.e. NOT). + * {{{ + * defineCodeGen(ctx, ev, c => s"!($c)") + * }}} + * + * @param f function that accepts a variable name and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: String => String): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index aad8479dafe41..a049f8878ed32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -87,6 +87,7 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => + /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = "" override def dataType: DataType = left.dataType @@ -119,9 +120,9 @@ abstract class BinaryArithmetic extends BinaryExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { if (left.dataType.isInstanceOf[DecimalType]) { - evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") } else { - evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } ) + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } } @@ -205,6 +206,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } + /** + * Special case handling due to division by 0 => null. + */ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -221,8 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultValue(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -263,6 +266,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } + /** + * Special case handling for x % 0 ==> null. + */ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -279,8 +285,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet eval1.code + eval2.code + s""" boolean ${ev.nullTerm} = false; - ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = - ${ctx.defaultValue(left.dataType)}; + ${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)}; if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { ${ev.nullTerm} = true; } else { @@ -337,7 +342,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor(^) of two numbers. + * A function that calculates bitwise xor of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 80c51cb3588ad..f1d8313b5f1dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -67,14 +67,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = - ${ctx.defaultValue(DecimalType())}; + org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())}; if (!${ev.nullTerm}) { - ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); - ${ev.primitiveTerm} = - ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); - ${ev.nullTerm} = ${ev.primitiveTerm} == null; + ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); + ${ev.primitiveTerm} = + ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); + ${ev.nullTerm} = ${ev.primitiveTerm} == null; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 21e21000c9437..1899c47613aae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -88,6 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; """ } else { + // TODO(cg): Add support for more data types. dataType match { case StringType => val v = value.asInstanceOf[UTF8String] @@ -96,12 +97,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres final boolean ${ev.nullTerm} = false; ${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr}); """ - case FloatType => + case FloatType => // This must go before NumericType s""" final boolean ${ev.nullTerm} = false; float ${ev.primitiveTerm} = ${value}f; """ - case dt: DecimalType => + case dt: DecimalType => // This must go before NumericType s""" final boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dt)} ${ev.primitiveTerm} = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 7b26bd2697195..e380eafc3fc2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -61,9 +61,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression { children.map { e => val eval = e.gen(ctx) s""" - if(${ev.nullTerm}) { + if (${ev.nullTerm}) { ${eval.code} - if(!${eval.nullTerm}) { + if (!${eval.nullTerm}) { ${ev.nullTerm} = false; ${ev.primitiveTerm} = ${eval.primitiveTerm}; } @@ -137,9 +137,9 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val code = children.map { e => val eval = e.gen(ctx) s""" - if($nonnull < $n) { + if ($nonnull < $n) { ${eval.code} - if(!${eval.nullTerm}) { + if (!${eval.nullTerm}) { $nonnull += 1; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index ad4535a09e04e..67cac26fd0d55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -85,7 +85,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - castOrNull(ctx, ev, c => s"!($c)") + defineCodeGen(ctx, ev, c => s"!($c)") } } @@ -220,13 +220,13 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { left.dataType match { - case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" }) case TimestampType => // java.sql.Timestamp does not have compare() super.genCode(ctx, ev) - case other => evaluate (ctx, ev, { + case other => defineCodeGen (ctx, ev, { (c1, c2) => s"$c1.compare($c2) $symbol 0" }) } @@ -277,7 +277,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - evaluate(ctx, ev, ctx.equalFunc(left.dataType)) + defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -392,7 +392,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi boolean ${ev.nullTerm} = false; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; ${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) { ${trueEval.code} ${ev.nullTerm} = ${trueEval.nullTerm}; ${ev.primitiveTerm} = ${trueEval.primitiveTerm};