From e57959d60bb841851623898790a5cb1cba314cdd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 17:00:17 -0700 Subject: [PATCH] add type alias --- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../spark/sql/catalyst/expressions/Cast.scala | 6 +-- .../sql/catalyst/expressions/Expression.scala | 20 ++------- .../sql/catalyst/expressions/arithmetic.scala | 16 +++---- .../expressions/codegen/CodeGenerator.scala | 42 +++++++++++-------- .../expressions/codegen/package.scala | 3 ++ .../expressions/decimalFunctions.scala | 6 +-- .../sql/catalyst/expressions/literals.scala | 6 +-- .../catalyst/expressions/nullFunctions.scala | 8 ++-- .../sql/catalyst/expressions/predicates.scala | 18 ++++---- .../spark/sql/catalyst/expressions/sets.scala | 14 +++---- 11 files changed, 69 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1d7f3b766a160..5978d1c931f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: 
EvaluatedExpression): Code = { s""" final boolean ${ev.nullTerm} = i.isNullAt($ordinal); final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bf8642cdde535..bcd7781c09e00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = this match { case Cast(child @ BinaryType(), StringType) => castOrNull (ctx, ev, c => @@ -465,7 +465,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))") case other => - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 9b89a4bc744c3..6efa08626795e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -62,28 +62,14 @@ abstract class Expression extends TreeNode[Expression] { val primitiveTerm = ctx.freshName("primitiveTerm") val objectTerm = ctx.freshName("objectTerm") val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm) - ve.code = genSource(ctx, ve) - - // Only inject debugging code if debugging is turned on. - // val debugCode = - // if (debugLogging) { - // val localLogger = log - // val localLoggerTree = reify { localLogger } - // s""" - // $localLoggerTree.debug( - // ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString)) - // """ - // } else { - // "" - // } - + ve.code = genCode(ctx, ve) ve } /** * Returns Java source code for this expression */ - def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val e = this.asInstanceOf[Expression] ctx.references += e s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 79350dd3d65f2..6ae815e1d0096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -117,7 +117,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (left.dataType.isInstanceOf[DecimalType]) { evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } ) } else { @@ -205,7 +205,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -263,7 +263,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -406,7 +406,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -430,7 +430,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } else { - super.genSource(ctx, ev) + super.genCode(ctx, 
ev) } } override def toString: String = s"MaxOf($left, $right)" @@ -460,7 +460,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) @@ -486,7 +486,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } else { - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4f21a1892df25..c87258c622664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -41,16 +41,22 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * valid if `nullTerm` is set to `true`. * @param objectTerm A possibly boxed version of the result of evaluating this expression. */ -case class EvaluatedExpression(var code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) +case class EvaluatedExpression(var code: Code, + nullTerm: Term, + primitiveTerm: Term, + objectTerm: Term) /** - * A context for codegen - * @param references the expressions that don't support codegen + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. 
*/ -case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { +class CodeGenContext { + + /** + * Holds all the expressions that do not support codegen; they will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() protected val stringType = classOf[UTF8String].getName protected val decimalType = classOf[Decimal].getName @@ -63,11 +69,11 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): String = { + def freshName(prefix: String): Term = { s"$prefix${curId.getAndIncrement}" } - def getColumn(dataType: DataType, ordinal: Int): String = { + def getColumn(dataType: DataType, ordinal: Int): Code = { dataType match { case StringType => s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)" case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)" @@ -75,7 +81,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { } } - def setColumn(destinationRow: String, dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = { dataType match { case StringType => s"$destinationRow.update($ordinal, $value)" case dt: DataType if isNativeType(dt) => @@ -84,17 +90,17 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { } } - def accessorForType(dt: DataType): String = dt match { + def accessorForType(dt: DataType): Term = dt match { case IntegerType => "getInt" case other => s"get${boxedType(dt)}" } - def mutatorForType(dt: DataType): String = dt match { + def mutatorForType(dt: DataType): Term = dt match { case IntegerType => "setInt" case other => s"set${boxedType(dt)}" } - def hashSetForType(dt: DataType): String = dt match { + def hashSetForType(dt: DataType): Term = dt match { case
IntegerType => classOf[IntegerHashSet].getName case LongType => classOf[LongHashSet].getName case unsupportedType => @@ -104,7 +110,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Return the primitive type for a DataType */ - def primitiveType(dt: DataType): String = dt match { + def primitiveType(dt: DataType): Term = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -123,7 +129,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Return the representation of default value for given DataType */ - def defaultValue(dt: DataType): String = dt match { + def defaultValue(dt: DataType): Term = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "-1" @@ -140,7 +146,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Return the boxed type in Java */ - def boxedType(dt: DataType): String = dt match { + def boxedType(dt: DataType): Term = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -159,7 +165,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) { /** * Returns a function to generate equal expression in Java */ - def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { + def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" } case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" } @@ -257,6 +263,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * expressions that don't support codegen */ def newCodeGenContext(): CodeGenContext = { - new CodeGenContext(new mutable.ArrayBuffer[Expression]()) + new CodeGenContext } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 7f1b12cdd5800..6f9589d20445e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils */ package object codegen { + type Term = String + type Code = String + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 68daea725cd40..250fe00b174bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; @@ -63,7 +63,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un } } - 
override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = ${eval.nullTerm}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 366e1083eb687..159df36ececff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, EvaluatedExpression} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -81,7 +81,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { if (value == null) { s""" final boolean ${ev.nullTerm} = true; @@ -113,7 +113,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value; """ case other => - super.genSource(ctx, ev) + super.genCode(ctx, ev) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 79c97f651f540..46582173e93b0 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { s""" boolean ${ev.nullTerm} = true; ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)}; @@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" final boolean ${ev.nullTerm} = false; @@ -101,7 +101,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.nullTerm} = false; @@ -132,7 +132,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3c1eeb07a91a4..4cd8bff0f4d47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala 
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -84,7 +84,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { castOrNull(ctx, ev, c => s"!($c)") } } @@ -146,7 +146,7 @@ case class And(left: Expression, right: Expression) } } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" @@ -192,7 +192,7 @@ case class Or(left: Expression, right: Expression) } } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" @@ -218,14 +218,14 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { left.dataType match { case dt: NumericType if ctx.isNativeType(dt) => evaluate (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" }) case TimestampType => // java.sql.Timestamp does not have compare() - 
super.genSource(ctx, ev) + super.genCode(ctx, ev) case other => evaluate (ctx, ev, { (c1, c2) => s"$c1.compare($c2) $symbol 0" }) @@ -276,7 +276,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression) = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression) = { evaluate(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -300,7 +300,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitiveTerm, eval2.primitiveTerm) @@ -383,7 +383,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi falseValue.eval(input) } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 22755b6ecb7e9..d62212d669276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -61,7 +61,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: 
EvaluatedExpression): Code = { elementType match { case IntegerType | LongType => s""" @@ -69,7 +69,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { ${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} = new ${ctx.hashSetForType(elementType)}(); """ - case _ => super.genSource(ctx, ev) + case _ => super.genCode(ctx, ev) } } @@ -104,7 +104,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -119,7 +119,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { boolean ${ev.nullTerm} = false; ${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm}; """ - case _ => super.genSource(ctx, ev) + case _ => super.genCode(ctx, ev) } } @@ -157,7 +157,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -170,7 +170,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres ${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm}; ${ev.primitiveTerm}.union(${rightEval.primitiveTerm}); """ - case _ => super.genSource(ctx, ev) + case _ => super.genCode(ctx, ev) } } } @@ -191,7 +191,7 @@ case class CountSet(child: Expression) extends UnaryExpression { } } - override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = { + override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = { castOrNull(ctx, ev, c => 
s"$c.size().toLong()") }