Skip to content

Commit

Permalink
improve coverage and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 6, 2015
1 parent bad6828 commit f42c732
Show file tree
Hide file tree
Showing 20 changed files with 440 additions and 237 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ abstract class Expression extends TreeNode[Expression] {
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val nullTerm = ctx.freshName("nullTerm")
val primitiveTerm = ctx.freshName("primitiveTerm")
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm)
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve
}
Expand All @@ -82,11 +82,11 @@ abstract class Expression extends TreeNode[Expression] {
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
final boolean ${ev.nullTerm} = ${objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm};
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = ${objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm};
}
"""
}
Expand Down Expand Up @@ -175,18 +175,18 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
val resultCode = f(eval1.primitive, eval2.primitive)

s"""
${eval1.code}
boolean ${ev.nullTerm} = ${eval1.nullTerm};
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.nullTerm}) {
${ev.primitiveTerm} = $resultCode;
if(!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.nullTerm} = true;
${ev.isNull} = true;
}
}
"""
Expand Down Expand Up @@ -216,12 +216,12 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
val eval = child.gen(ctx)
// reuse the previous nullTerm
ev.nullTerm = eval.nullTerm
// reuse the previous isNull
ev.isNull = eval.isNull
eval.code + s"""
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
if (!${ev.nullTerm}) {
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = ${f(eval.primitive)};
}
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {

private lazy val numeric = TypeUtils.getNumeric(dataType)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
}

protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
}

Expand All @@ -68,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
if (value < 0) null
else math.sqrt(value)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
if (${eval.primitive} < 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
}
}
"""
}
}

/**
Expand Down Expand Up @@ -216,9 +236,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitiveTerm}.isZero()"
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitiveTerm} == 0"
s"${eval2.primitive} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
Expand All @@ -227,12 +247,12 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
}
"""
}
Expand Down Expand Up @@ -276,9 +296,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitiveTerm}.isZero()"
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitiveTerm} == 0"
s"${eval2.primitive} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
Expand All @@ -287,12 +307,12 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
eval1.code + eval2.code +
s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
${ev.nullTerm} = true;
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
}
"""
}
Expand Down Expand Up @@ -387,6 +407,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)")
}

protected override def evalInternal(evalE: Any) = not(evalE)
}

Expand Down Expand Up @@ -419,21 +443,21 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
${ev.primitiveTerm} = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
${ev.nullTerm} = ${eval1.nullTerm};
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.primitive} > ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitiveTerm} = ${eval2.primitiveTerm};
${ev.primitive} = ${eval2.primitive};
}
}
"""
Expand Down Expand Up @@ -475,21 +499,21 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
val eval2 = right.gen(ctx)

eval1.code + eval2.code + s"""
boolean ${ev.nullTerm} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.nullTerm}) {
${ev.nullTerm} = ${eval2.nullTerm};
${ev.primitiveTerm} = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
${ev.nullTerm} = ${eval1.nullTerm};
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
${ev.primitiveTerm} = ${eval1.primitiveTerm};
if (${eval1.primitive} < ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitiveTerm} = ${eval2.primitiveTerm};
${ev.primitive} = ${eval2.primitive};
}
}
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* Java source for evaluating an [[Expression]] given a [[Row]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
* @param nullTerm A term that holds a boolean value representing whether the expression evaluated
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
* valid if `nullTerm` is set to `true`.
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term)
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)

/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
Expand Down Expand Up @@ -149,9 +149,9 @@ class CodeGenContext {
def defaultValue(dt: DataType): Term = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "-1"
case LongType => "-1"
case ByteType => "-1"
case ShortType => "(short)-1"
case LongType => "-1L"
case ByteType => "(byte)-1"
case DoubleType => "-1.0"
case IntegerType => "-1"
case DateType => "-1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val evaluationCode = e.gen(ctx)
evaluationCode.code +
s"""
if(${evaluationCode.nullTerm})
if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)};
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
val code = s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
case BinaryType =>
s"""
{
byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
int j = 0;
while (j < x.length && j < y.length) {
if (x[j] != y[j]) return x[j] - y[j];
Expand All @@ -73,16 +73,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
}"""
case _: NumericType =>
s"""
if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
if (${evalA.primitive} != ${evalB.primitive}) {
if (${evalA.primitive} > ${evalB.primitive}) {
return ${if (asc) "1" else "-1"};
} else {
return ${if (asc) "-1" else "1"};
}
}"""
case _ =>
s"""
int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
int comp = ${evalA.primitive}.compare(${evalB.primitive});
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}"""
Expand All @@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
${evalA.code}
i = $b;
${evalB.code}
if (${evalA.nullTerm} && ${evalB.nullTerm}) {
if (${evalA.isNull} && ${evalB.isNull}) {
// Nothing
} else if (${evalA.nullTerm}) {
} else if (${evalA.isNull}) {
return ${if (order.direction == Ascending) "-1" else "1"};
} else if (${evalB.nullTerm}) {
} else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
$compare
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
@Override
public boolean eval(Row i) {
${eval.code}
return !${eval.nullTerm} && ${eval.primitiveTerm};
return !${eval.isNull} && ${eval.primitive};
}
}"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
{
// column$i
${eval.code}
nullBits[$i] = ${eval.nullTerm};
if(!${eval.nullTerm}) {
c$i = ${eval.primitiveTerm};
nullBits[$i] = ${eval.isNull};
if (!${eval.isNull}) {
c$i = ${eval.primitive};
}
}
"""
Expand Down Expand Up @@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case LongType => s"$col ^ ($col >>> 32)"
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
${ctx.decimalType} ${ev.primitiveTerm} = null;
boolean ${ev.isNull} = ${eval.isNull};
${ctx.decimalType} ${ev.primitive} = null;

if (!${ev.nullTerm}) {
${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull(
${eval.primitiveTerm}, $precision, $scale);
${ev.nullTerm} = ${ev.primitiveTerm} == null;
if (!${ev.isNull}) {
${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
${eval.primitive}, $precision, $scale);
${ev.isNull} = ${ev.primitive} == null;
}
"""
}
Expand Down
Loading

0 comments on commit f42c732

Please sign in to comment.