diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1ffc95c676f6f..1055be6e9d273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + s""" + final boolean ${ev.nullTerm} = i.isNullAt($ordinal); + final ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? + ${ctx.defaultPrimitive(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + """ + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 21adac144112e..a986844d18e8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -433,6 +434,42 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match { + + case Cast(child @ BinaryType(), StringType) => + castOrNull (ctx, ev, c => + s"new org.apache.spark.sql.types.UTF8String().set($c)", + StringType) + + case Cast(child @ DateType(), StringType) => + castOrNull(ctx, ev, c => + s"""new org.apache.spark.sql.types.UTF8String().set( + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", + StringType) + + case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c?1:0)", dt) + + case Cast(child @ DecimalType(), IntegerType) => + castOrNull(ctx, ev, c => s"($c).toInt()", IntegerType) + + case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"($c).to${ctx.termForType(dt)}()", dt) + + case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => + castOrNull(ctx, ev, c => s"(${ctx.primitiveForType(dt)})($c)", dt) + + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. 
+ case Cast(e, StringType) if e.dataType != TimestampType => + castOrNull(ctx, ev, c => + s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))", + StringType) + + case other => + super.genSource(ctx, ev) + } } object Cast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3cf851aec15ea..f66f8f9ff105e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -51,6 +52,51 @@ abstract class Expression extends TreeNode[Expression] { /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): Any + /** + * Returns an [[EvaluatedExpression]], which contains Java source code that + * can be used to generate the result of evaluating the expression on an input row. + * @param ctx a [[CodeGenContext]] + */ + def gen(ctx: CodeGenContext): EvaluatedExpression = { + val nullTerm = ctx.freshName("nullTerm") + val primitiveTerm = ctx.freshName("primitiveTerm") + val objectTerm = ctx.freshName("objectTerm") + val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm) + ve.code = genSource(ctx, ve) + + // Only inject debugging code if debugging is turned on. + // val debugCode = + // if (debugLogging) { + // val localLogger = log + // val localLoggerTree = reify { localLogger } + // s""" + // $localLoggerTree.debug( + // ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString)) + // """ + // } else { + // "" + // } + + ve + } + + /** + * Returns Java source code for this expression + */ + def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val e = this.asInstanceOf[Expression] + ctx.references += e + s""" + /* expression: ${this} */ + Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.nullTerm} = ${ev.objectTerm} == null; + ${ctx.primitiveForType(e.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(e.dataType)}; + if (!${ev.nullTerm}) ${ev.primitiveTerm} = + (${ctx.termForType(e.dataType)})${ev.objectTerm}; + """ + } + /** * Returns `true` if this expression and all its children have been resolved to a specific schema * and input data types checking passed, and `false` if it still contains any unresolved @@ -116,6 +162,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"($left $symbol $right)" + + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function from two primitive term names to a tree that evaluates them. 
+ */ + def evaluate(ctx: CodeGenContext, + ev: EvaluatedExpression, + f: (String, String) => String): String = + evaluateAs(left.dataType)(ctx, ev, f) + + def evaluateAs(resultType: DataType)(ctx: CodeGenContext, + ev: EvaluatedExpression, + f: (String, String) => String): String = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (left.dataType != right.dataType) { + // log.warn(s"${left.dataType} != ${right.dataType}") + } + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) + + eval1.code + eval2.code + + s""" + boolean ${ev.nullTerm} = ${eval1.nullTerm} || ${eval2.nullTerm}; + ${ctx.primitiveForType(resultType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(resultType)}; + if(!${ev.nullTerm}) { + ${ev.primitiveTerm} = (${ctx.primitiveForType(resultType)})($resultCode); + } + """ + } } abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { @@ -124,6 +205,19 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + def castOrNull(ctx: CodeGenContext, + ev: EvaluatedExpression, + f: String => String, dataType: DataType): String = { + val eval = child.gen(ctx) + eval.code + + s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = ${f(eval.primitiveTerm)}; + } + """ + } } // TODO Semantically we probably not need GroupExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2ac53f8f6613f..4320fbf51bd6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -86,6 +87,8 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => + def decimalMethod: String = "" + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -114,12 +117,21 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (left.dataType.isInstanceOf[DecimalType]) { + evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) + } else { + evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1 $symbol $eval2" } ) + } + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryArithmetics must override either eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -134,6 +146,7 @@ case class Add(left: Expression, right: Expression) extends 
BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" + override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -148,6 +161,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" + override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -162,6 +176,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" + override def decimalMethod: String = "$divide" + override def nullable: Boolean = true override lazy val resolved = @@ -188,10 +204,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitiveTerm}.isZero()" + } else { + s"${eval2.primitiveTerm} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { + ${ev.nullTerm} = true; + } else { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + } + """ + } } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" + override def decimalMethod: String = "reminder" + override def nullable: Boolean = true override lazy val resolved = @@ -218,6 +262,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitiveTerm}.isZero()" + } else { + s"${eval2.primitiveTerm} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) { + ${ev.nullTerm} = true; + } else { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm}); + } + """ + } } /** @@ -336,6 +406,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (ctx.isNativeType(left.dataType)) { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + eval1.code + eval2.code + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + + if (${eval1.nullTerm}) { + ${ev.nullTerm} = ${eval2.nullTerm}; + ${ev.primitiveTerm} = 
${eval2.primitiveTerm}; + } else if (${eval2.nullTerm}) { + ${ev.nullTerm} = ${eval1.nullTerm}; + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + } + } + """ + } else { + super.genSource(ctx, ev) + } + } override def toString: String = s"MaxOf($left, $right)" } @@ -363,5 +460,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (ctx.isNativeType(left.dataType)) { + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + eval1.code + eval2.code + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(left.dataType)} ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(left.dataType)}; + + if (${eval1.nullTerm}) { + ${ev.nullTerm} = ${eval2.nullTerm}; + ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + } else if (${eval2.nullTerm}) { + ${ev.nullTerm} = ${eval1.nullTerm}; + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { + ${ev.primitiveTerm} = ${eval1.primitiveTerm}; + } else { + ${ev.primitiveTerm} = ${eval2.primitiveTerm}; + } + } + """ + } else { + super.genSource(ctx, ev) + } + } + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd604121b7dd9..bec1899a3aad2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -33,586 +32,50 @@ class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] /** - * A base class for generators of byte code to perform expression evaluation. Includes a set of - * helpers for referring to Catalyst types and building trees that perform evaluation of individual - * expressions. + * Java source for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param nullTerm A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not + * valid if `nullTerm` is set to `true`. + * @param objectTerm A possibly boxed version of the result of evaluating this expression. 
*/ -abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { +case class EvaluatedExpression(var code: String, + nullTerm: String, + primitiveTerm: String, + objectTerm: String) + +/** + * A context for codegen + * @param references the expressions that don't support codegen + */ +case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { - protected val rowType = classOf[Row].getName protected val stringType = classOf[UTF8String].getName protected val decimalType = classOf[Decimal].getName - protected val exprType = classOf[Expression].getName - protected val mutableRowType = classOf[MutableRow].getName - protected val genericMutableRowType = classOf[GenericMutableRow].getName private val curId = new java.util.concurrent.atomic.AtomicInteger() - /** - * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. - */ - var debugLogging = false - - /** - * Generates a class for a given input expression. Called when there is not cached code - * already available. - */ - protected def create(in: InType): OutType - - /** - * Canonicalizes an input expression. Used to avoid double caching expressions that differ only - * cosmetically. - */ - protected def canonicalize(in: InType): InType - - /** Binds an input expression to a given input schema */ - protected def bind(in: InType, inputSchema: Seq[Attribute]): InType - - /** - * Compile the Java source code into a Java class, using Janino. - * - * It will track the time used to compile - */ - protected def compile(code: String): Class[_] = { - val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") - clazz - } - - /** - * A cache of generated classes. - * - * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most - * fundamental difference is that a ConcurrentMap persists all elements that are added to it until - * they are explicitly removed. A Cache on the other hand is generally configured to evict entries - * automatically, in order to constrain its memory footprint. Note that this cache does not use - * weak keys/values and thus does not respond to memory pressure. - */ - protected val cache = CacheBuilder.newBuilder() - .maximumSize(1000) - .build( - new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = { - val startTime = System.nanoTime() - val result = create(in) - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logInfo(s"Code generated expression $in in $timeMs ms") - result - } - }) - - /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = - generate(bind(expressions, inputSchema)) - - /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) - /** * Returns a term name that is unique within this instance of a `CodeGenerator`. * * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) 
*/ - protected def freshName(prefix: String): String = { + def freshName(prefix: String): String = { s"$prefix${curId.getAndIncrement}" } - /** - * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. - * - * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. - */ - protected case class EvaluatedExpression( - code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) - - /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. - */ - protected class CodeGenContext { - /** - * Holding all the expressions those do not support codegen, will be evaluated directly. - */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - } - - /** - * Create a new codegen context for expression evaluator, used to store those - * expressions that don't support codegen - */ - def newCodeGenContext(): CodeGenContext = { - new CodeGenContext() - } - - /** - * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that - * can be used to determine the result of evaluating the expression on an input row. - */ - def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { - val primitiveTerm = freshName("primitiveTerm") - val nullTerm = freshName("nullTerm") - val objectTerm = freshName("objectTerm") - - implicit class Evaluate1(e: Expression) { - def castOrNull(f: String => String, dataType: DataType): String = { - val eval = expressionEvaluator(e, ctx) - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - if (!$nullTerm) { - $primitiveTerm = ${f(eval.primitiveTerm)}; - } - """ - } - } - - implicit class Evaluate2(expressions: (Expression, Expression)) { - - /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function from two primitive term names to a tree that evaluates them. - */ - def evaluate(f: (String, String) => String): String = - evaluateAs(expressions._1.dataType)(f) - - def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... 
- if (expressions._1.dataType != expressions._2.dataType) { - log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") - } - - val eval1 = expressionEvaluator(expressions._1, ctx) - val eval2 = expressionEvaluator(expressions._2, ctx) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; - if(!$nullTerm) { - $primitiveTerm = (${primitiveForType(resultType)})($resultCode); - } - """ - } - } - - val inputTuple = "i" - - // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, String] = { - case b @ BoundReference(ordinal, dataType, nullable) => - s""" - final boolean $nullTerm = $inputTuple.isNullAt($ordinal); - final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? - ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); - """ - - case expressions.Literal(null, dataType) => - s""" - final boolean $nullTerm = true; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - """ - - case expressions.Literal(value: UTF8String, StringType) => - val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" - s""" - final boolean $nullTerm = false; - ${stringType} $primitiveTerm = - new ${stringType}().set(${arr}); - """ - - case expressions.Literal(value, FloatType) => - s""" - final boolean $nullTerm = false; - float $primitiveTerm = ${value}f; - """ - - case expressions.Literal(value, dt @ DecimalType()) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); - """ - - case expressions.Literal(value, dataType) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dataType)} $primitiveTerm = $value; - """ - - case Cast(child @ BinaryType(), StringType) => - child.castOrNull(c => - s"new ${stringType}().set($c)", - StringType) - - case Cast(child @ DateType(), StringType) => - child.castOrNull(c => - s"""new ${stringType}().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) - - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - - case Cast(child @ DecimalType(), IntegerType) => - child.castOrNull(c => s"($c).toInt()", IntegerType) - - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - - case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) - - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. 
- case Cast(e, StringType) if e.dataType != TimestampType => - e.castOrNull(c => - s"new ${stringType}().set(String.valueOf($c))", - StringType) - - case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => - (e1, e2).evaluateAs (BooleanType) { - case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" - } - - case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } - - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } - - case And(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { - } else { - ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true; - } else { - $nullTerm = true; - } - } - """ - - case Or(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true; - } else { - ${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true; - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false; - } else { - $nullTerm = true; - } - } - """ - - case Not(child) => - // Uh, bad function name... 
- child.castOrNull(c => s"!$c", BooleanType) - - case Add(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } - case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } - case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } - case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = null; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); - } - """ - case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); - } - """ - - case Add(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } - case Subtract(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } - case Multiply(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } - case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; - } - """ - case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; - } - """ - - case IsNotNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = !${eval.nullTerm}; - """ - - case IsNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = ${eval.nullTerm}; - """ - - case e @ Coalesce(children) => - s""" - boolean $nullTerm = true; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - """ + - children.map { c => - val eval = expressionEvaluator(c, ctx) - s""" - if($nullTerm) { - ${eval.code} - if(!${eval.nullTerm}) { - $nullTerm = false; - $primitiveTerm = ${eval.primitiveTerm}; - } - } - """ - }.mkString("\n") - - case e @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition, ctx) - val trueEval = expressionEvaluator(trueValue, ctx) - val falseEval = expressionEvaluator(falseValue, ctx) - - s""" - boolean $nullTerm = false; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - 
${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ${trueEval.code} - $nullTerm = ${trueEval.nullTerm}; - $primitiveTerm = ${trueEval.primitiveTerm}; - } else { - ${falseEval.code} - $nullTerm = ${falseEval.nullTerm}; - $primitiveTerm = ${falseEval.primitiveTerm}; - } - """ - - case NewSet(elementType) => - s""" - boolean $nullTerm = false; - ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); - """ - - case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item, ctx) - val setEval = expressionEvaluator(set, ctx) - - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean $nullTerm = false; - ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; - """ - - case CombineSets(left, right) => - val leftEval = expressionEvaluator(left, ctx) - val rightEval = expressionEvaluator(right, ctx) - - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - leftEval.code + rightEval.code + - s""" - boolean $nullTerm = false; - ${htype} $primitiveTerm = - (${htype})${leftEval.primitiveTerm}; - $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); - """ - - case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case UnscaledValue(child) => - val childEval = expressionEvaluator(child, ctx) - - childEval.code + - s""" - boolean $nullTerm = ${childEval.nullTerm}; - long $primitiveTerm = $nullTerm ? 
-1 : ${childEval.primitiveTerm}.toUnscaledLong(); - """ - - case MakeDecimal(child, precision, scale) => - val eval = expressionEvaluator(child, ctx) - - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; - - if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal(); - $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); - $nullTerm = $primitiveTerm == null; - } - """ - } - - // If there was no match in the partial function above, we fall back on calling the interpreted - // expression evaluator. - val code: String = - primitiveEvaluation.lift.apply(e).getOrElse { - logError(s"No rules to generate $e") - ctx.references += e - s""" - /* expression: ${e} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean $nullTerm = $objectTerm == null; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; - """ - } - - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) - } - - protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { + def getColumn(dataType: DataType, ordinal: Int): String = { dataType match { - case StringType => s"(${stringType})$inputRow.apply($ordinal)" - case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" + case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" + case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" + case _ => s"(${termForType(dataType)})i.apply($ordinal)" } } - protected def setColumn( - destinationRow: String, - dataType: DataType, - ordinal: Int, - value: String): String = { + def setColumn(destinationRow: String, dataType: DataType, ordinal: Int, value: String): String = { dataType match { case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => @@ -621,24 +84,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } } - protected def accessorForType(dt: DataType) = dt match { + def accessorForType(dt: DataType): String = dt match { case IntegerType => "getInt" case other => s"get${termForType(dt)}" } - protected def mutatorForType(dt: DataType) = dt match { + def mutatorForType(dt: DataType): String = dt match { case IntegerType => "setInt" case other => s"set${termForType(dt)}" } - protected def hashSetForType(dt: DataType): String = dt match { + def hashSetForType(dt: DataType): String = dt match { case IntegerType => classOf[IntegerHashSet].getName case LongType => classOf[LongHashSet].getName case unsupportedType => sys.error(s"Code generation not support for hashset of type $unsupportedType") } - protected def primitiveForType(dt: DataType): String = dt match { + def primitiveForType(dt: DataType): String = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -654,7 +117,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case _ => "Object" } - protected def defaultPrimitive(dt: DataType): String = dt match { + def defaultPrimitive(dt: DataType): String = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "-1" @@ -668,7 +131,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin 
case _ => "null" } - protected def termForType(dt: DataType): String = dt match { + def termForType(dt: DataType): String = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -687,11 +150,96 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * List of data types that have special accessors and setters in [[Row]]. */ - protected val nativeTypes = + val nativeTypes = Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. */ - protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) + def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) +} + +/** + * A base class for generators of byte code to perform expression evaluation. Includes a set of + * helpers for referring to Catalyst types and building trees that perform evaluation of individual + * expressions. + */ +abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { + + protected val rowType = classOf[Row].getName + protected val exprType = classOf[Expression].getName + protected val mutableRowType = classOf[MutableRow].getName + protected val genericMutableRowType = classOf[GenericMutableRow].getName + + /** + * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. + */ + var debugLogging = false + + /** + * Generates a class for a given input expression. Called when there is not cached code + * already available. + */ + protected def create(in: InType): OutType + + /** + * Canonicalizes an input expression. Used to avoid double caching expressions that differ only + * cosmetically. + */ + protected def canonicalize(in: InType): InType + + /** Binds an input expression to a given input schema */ + protected def bind(in: InType, inputSchema: Seq[Attribute]): InType + + /** + * Compile the Java source code into a Java class, using Janino. + * + * It will track the time used to compile + */ + protected def compile(code: String): Class[_] = { + val startTime = System.nanoTime() + val clazz = new ClassBodyEvaluator(code).getClazz() + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") + clazz + } + + /** + * A cache of generated classes. + * + * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most + * fundamental difference is that a ConcurrentMap persists all elements that are added to it until + * they are explicitly removed. A Cache on the other hand is generally configured to evict entries + * automatically, in order to constrain its memory footprint. Note that this cache does not use + * weak keys/values and thus does not respond to memory pressure. + */ + protected val cache = CacheBuilder.newBuilder() + .maximumSize(1000) + .build( + new CacheLoader[InType, OutType]() { + override def load(in: InType): OutType = { + val startTime = System.nanoTime() + val result = create(in) + val endTime = System.nanoTime() + def timeMs: Double = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated expression $in in $timeMs ms") + result + } + }) + + /** Generates the requested evaluator binding the given expression(s) to the inputSchema. 
*/ + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) + + /** Generates the requested evaluator given already bound expression(s). */ + def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) + + /** + * Create a new codegen context for expression evaluator, used to store those + * expressions that don't support codegen + */ + def newCodeGenContext(): CodeGenContext = { + new CodeGenContext(new mutable.ArrayBuffer[Expression]()) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 638b53fe0fe2f..02b7d3fae6767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = expressionEvaluator(e, ctx) + val evaluationCode = e.gen(ctx) evaluationCode.code + s""" if(${evaluationCode.nullTerm}) mutableRow.setNullAt($i); else - ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0ff840dab393c..d3c219fddc53c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -52,8 +52,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child, ctx) - val evalB = expressionEvaluator(order.child, ctx) + val evalA = order.child.gen(ctx) + val evalB = order.child.gen(ctx) val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index fb18769f00da3..dd4474de05df9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { protected def create(predicate: Expression): ((Row) => Boolean) = { val ctx = newCodeGenContext() - val eval = expressionEvaluator(predicate, ctx) + val eval = predicate.gen(ctx) val code = s""" import org.apache.spark.sql.Row; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d5be1fc12e0f0..0e8ad76f65bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -45,12 +45,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { case (e, i) => - s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + s"private ${ctx.primitiveForType(e.dataType)} c$i = ${ctx.defaultPrimitive(e.dataType)};\n" }.mkString("\n ") val initColumns = expressions.zipWithIndex.map { case (e, i) => - val eval = expressionEvaluator(e, ctx) + val eval = e.gen(ctx) s""" { // column$i @@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + s"case $i: { c$i = (${ctx.termForType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = nativeTypes.map { dataType => + val specificAccessorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: return c$i;" @@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + public ${ctx.primitiveForType(dataType)} ${ctx.accessorForType(dataType)}(int i) { if (isNullAt(i)) { - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultPrimitive(dataType)}; } switch (i) { $cases } - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultPrimitive(dataType)}; }""" } else { "" } }.mkString("\n") - val specificMutatorFunctions = nativeTypes.map { dataType => + val specificMutatorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: { c$i = value; return; }" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveForType(dataType)} value) { nullBits[i] = false; switch (i) { $cases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 65ba18924afe1..76273a5b7ee68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -35,6 +36,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { childResult.asInstanceOf[Decimal].toUnscaledLong } } + + override def genSource(ctx: 
CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code +s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + long ${ev.primitiveTerm} = ${ev.nullTerm} ? -1 : ${eval.primitiveTerm}.toUnscaledLong(); + """ + } } /** Create a Decimal from an unscaled Long value */ @@ -53,4 +62,20 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.nullTerm} = ${eval.nullTerm}; + org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = + ${ctx.defaultPrimitive(DecimalType())}; + + if (!${ev.nullTerm}) { + ${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal(); + ${ev.primitiveTerm} = + ${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale); + ${ev.nullTerm} = ${ev.primitiveTerm} == null; + } + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d3ca3d9a4b18b..d9fbda9511a5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -79,6 +80,43 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" override def eval(input: Row): Any = value + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + if (value == null) { + s""" + final boolean ${ev.nullTerm} = true; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + """ + } else { + dataType match { + case StringType => + val v = value.asInstanceOf[UTF8String] + val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}" + s""" + final boolean ${ev.nullTerm} = false; + org.apache.spark.sql.types.UTF8String ${ev.primitiveTerm} = + new org.apache.spark.sql.types.UTF8String().set(${arr}); + """ + case FloatType => + s""" + final boolean ${ev.nullTerm} = false; + float ${ev.primitiveTerm} = ${value}f; + """ + case dt: DecimalType => + s""" + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(dt)} ${ev.primitiveTerm} = new ${ctx.primitiveForType(dt)}().set($value); + """ + case dt: NumericType => + s""" + final boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = $value; + """ + case other => + super.genSource(ctx, ev) + } + } + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5070570b4740d..2af0f96146c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } result } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + s""" + boolean ${ev.nullTerm} = true; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + """ + + children.map { e => + val eval = e.gen(ctx) + s""" + if(${ev.nullTerm}) { + ${eval.code} + if(!${eval.nullTerm}) { + ${ev.nullTerm} = false; + ${ev.primitiveTerm} = ${eval.primitiveTerm}; + } + } + """ + }.mkString("\n") + } } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { @@ -61,6 +81,14 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code + s""" + final boolean ${ev.nullTerm} = false; + final boolean ${ev.primitiveTerm} = ${eval.nullTerm}; + """ + } + override def toString: String = s"IS NULL $child" } @@ -72,6 +100,14 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def eval(input: Row): Any = { child.eval(input) != null } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = !${eval.nullTerm}; + """ + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 807021d50e8e0..b6b2c7db28960 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -82,6 +83,11 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex case b: Boolean => !b } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + // Uh, bad function name... 
+ castOrNull(ctx, ev, c => s"!($c)", BooleanType) + } } /** @@ -141,6 +147,26 @@ case class And(left: Expression, right: Expression) } } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = false; + + if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { + } else { + ${eval2.code} + if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { + } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { + ${ev.primitiveTerm} = true; + } else { + ${ev.nullTerm} = true; + } + } + """ + } } case class Or(left: Expression, right: Expression) @@ -167,10 +193,44 @@ case class Or(left: Expression, right: Expression) } } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.nullTerm} = false; + boolean ${ev.primitiveTerm} = false; + + if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { + ${ev.primitiveTerm} = true; + } else { + ${eval2.code} + if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { + ${ev.primitiveTerm} = true; + } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { + ${ev.primitiveTerm} = false; + } else { + ${ev.nullTerm} = true; + } + } + """ + } } abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + left.dataType match { + case dt: NumericType => evaluateAs(BooleanType) (ctx, ev, { + (eval1, eval2) => s"$eval1 $symbol $eval2" + }) + case dt: TimestampType => + super.genSource(ctx, ev) + case other => evaluateAs(BooleanType) (ctx, ev, { + (eval1, eval2) => s"$eval1.compare($eval2) $symbol 0" + }) + } + } override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { @@ -216,6 +276,17 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression) = { + left.dataType match { + case BinaryType() => + evaluateAs (BooleanType) (ctx, ev, { + case (eval1, eval2) => + s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" + }) + case other => + evaluateAs (BooleanType) (ctx, ev, { case (eval1, eval2) => s"$eval1 == $eval2" }) + } + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -236,6 +307,22 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp l == r } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val cmpCode = if (left.dataType.isInstanceOf[BinaryType]) { + s"java.util.Arrays.equals((byte[])${eval1.primitiveTerm}, (byte[])${eval2.primitiveTerm})" + } else { + s"${eval1.primitiveTerm} == ${eval2.primitiveTerm}" + } + eval1.code + eval2.code + + s""" + final boolean ${ev.nullTerm} = false; + final boolean ${ev.primitiveTerm} = (${eval1.nullTerm} && ${eval2.nullTerm}) || + (!${eval1.nullTerm} && $cmpCode); + """ + } } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { @@ -309,6 +396,26 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } + override def genSource(ctx: 
CodeGenContext, ev: EvaluatedExpression): String = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + boolean ${ev.nullTerm} = false; + ${ctx.primitiveForType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultPrimitive(dataType)}; + ${condEval.code} + if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { + ${trueEval.code} + ${ev.nullTerm} = ${trueEval.nullTerm}; + ${ev.primitiveTerm} = ${trueEval.primitiveTerm}; + } else { + ${falseEval.code} + ${ev.nullTerm} = ${falseEval.nullTerm}; + ${ev.primitiveTerm} = ${falseEval.primitiveTerm}; + } + """ + } override def toString: String = s"if ($predicate) $trueValue else $falseValue" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b65bf165f21db..e6ae81c2aad52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -60,6 +61,14 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + s""" + boolean ${ev.nullTerm} = false; + ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = + new ${ctx.hashSetForType(elementType)}(); + """ + } + override def toString: String = s"new Set($dataType)" } @@ -91,6 +100,23 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = ctx.hashSetForType(elementType) + + itemEval.code + setEval.code + + s""" + if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { + (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); + } + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; + """ + } + override def toString: String = s"$set += $item" } @@ -124,6 +150,22 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres null } } + + override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + val htype = ctx.hashSetForType(elementType) + + leftEval.code + rightEval.code + + s""" + boolean ${ev.nullTerm} = false; + ${htype} ${ev.primitiveTerm} = + (${htype})${leftEval.primitiveTerm}; + ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); + """ + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 8cfd853afa35f..b577de1d5aab9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ 
-33,7 +33,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { } catch { case e: Throwable => val ctx = GenerateProjection.newCodeGenContext() - val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) + val evaluated = expression.gen(ctx) fail( s""" |Code generation of $expression failed: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 9ab1f7d7ad0db..9da72521ec3ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -29,7 +29,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) + lazy val evaluated = expression.gen(ctx) val plan = try { GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
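
To make the contract introduced by this refactor concrete, here is a minimal, hypothetical sketch of a new expression plugging into it. `IsPositive` is not part of this patch; it only illustrates the pattern the overrides above follow: call `child.gen(ctx)` to obtain the child's generated fragment, then return Java statements that assign `ev.nullTerm` and `ev.primitiveTerm` for this node. Any expression that does not override `genSource` keeps working through the default implementation, which appends the expression to `ctx.references` and generates a call back into interpreted `eval`.

    package org.apache.spark.sql.catalyst.expressions

    import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression}
    import org.apache.spark.sql.types._

    // Hypothetical example (not in this patch): IsPositive(child) is true when a native
    // numeric child is greater than zero. It follows the same shape as the genSource
    // overrides added in this diff (IsNull, Not, UnscaledValue, ...).
    case class IsPositive(child: Expression) extends UnaryExpression {
      override def dataType: DataType = BooleanType
      override def nullable: Boolean = child.nullable
      override def toString: String = s"IsPositive($child)"

      // Interpreted path, still required; also what the default genSource falls back to.
      override def eval(input: Row): Any = {
        val v = child.eval(input)
        if (v == null) null else v.asInstanceOf[Number].doubleValue() > 0
      }

      // Codegen path: eval.code evaluates the child; the appended statements declare this
      // node's null flag and primitive result. The castOrNull helper added to
      // UnaryExpression in this patch would produce an equivalent shape:
      //   castOrNull(ctx, ev, c => s"$c > 0", BooleanType)
      override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
        val eval = child.gen(ctx)
        eval.code +
          s"""
            boolean ${ev.nullTerm} = ${eval.nullTerm};
            boolean ${ev.primitiveTerm} = false;
            if (!${ev.nullTerm}) {
              ${ev.primitiveTerm} = ${eval.primitiveTerm} > 0;
            }
          """
      }
    }

The generator side keeps its shape: GenerateMutableProjection, GeneratePredicate, GenerateOrdering and GenerateProjection obtain the same fragments through `e.gen(ctx)` instead of `expressionEvaluator(e, ctx)`, and splice `code`, `nullTerm` and `primitiveTerm` into the class body handed to Janino, as the hunks above show.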