From 2344bc0d48fc2a3ec91de69a6233665a0ae3635e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Jun 2015 19:43:12 -0700 Subject: [PATCH] fix test --- .../expressions/codegen/CodeGenerator.scala | 17 +++++++++++++---- .../spark/sql/catalyst/expressions/sets.scala | 8 ++------ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4885eec08fca9..06cc6e1024b01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -166,9 +166,12 @@ class CodeGenContext { * Returns a function to generate equal expression in Java */ def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { - case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } - case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" } - case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" } + case BinaryType => { case (eval1, eval2) => + s"java.util.Arrays.equals($eval1, $eval2)" } + case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => + { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => + { case (eval1, eval2) => s"$eval1.equals($eval2)" } } /** @@ -221,7 +224,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ protected def compile(code: String): Class[_] = { val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() + val clazz = try { + new ClassBodyEvaluator(code).getClazz() + } catch { + case e: Exception => + logError(s"failed to compile:\n $code", e) + throw e + } val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 46cad9b019584..a0dae40d964e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -167,8 +167,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres leftEval.code + rightEval.code + s""" boolean ${ev.nullTerm} = false; - ${htype} ${ev.primitiveTerm} = ${leftEval.primitiveTerm}; - ${ev.primitiveTerm}.union(${rightEval.primitiveTerm}); + ${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm}; + ${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm}); """ case _ => super.genCode(ctx, ev) } @@ -191,9 +191,5 @@ case class CountSet(child: Expression) extends UnaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - castOrNull(ctx, ev, c => s"$c.size().toLong()") - } - override def toString: String = s"$child.count()" }