Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 6, 2015
1 parent f42c732 commit 9adaeaf
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
super.genCode(ctx, ev)

case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
case (dt: NumericType, BooleanType) =>
Expand All @@ -469,7 +469,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")

case other =>
super.genCode(ctx, ev)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,13 @@ abstract class Expression extends TreeNode[Expression] {
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val e = this.asInstanceOf[Expression]
ctx.references += e
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = ${objectTerm} == null;
${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
${ctx.javaType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm};
}
Expand Down Expand Up @@ -180,7 +179,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.isNull}) {
Expand Down Expand Up @@ -219,7 +218,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
// reuse the previous isNull
ev.isNull = eval.isNull
eval.code + s"""
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = ${f(eval.primitive)};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
if (${eval.primitive} < 0.0) {
${ev.isNull} = true;
Expand Down Expand Up @@ -144,7 +144,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev, (eval1, eval2) =>
s"(${ctx.primitiveType(dataType)})($eval1 $symbol $eval2)")
s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
Expand Down Expand Up @@ -248,7 +248,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
eval1.code + eval2.code +
s"""
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
Expand Down Expand Up @@ -308,7 +308,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
eval1.code + eval2.code +
s"""
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
Expand Down Expand Up @@ -408,7 +408,7 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)")
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
}

protected override def evalInternal(evalE: Any) = not(evalE)
Expand Down Expand Up @@ -444,7 +444,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.isNull}) {
Expand Down Expand Up @@ -500,7 +500,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {

eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};

if (${eval1.isNull}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class CodeGenContext {
*/
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()

val stringType = classOf[UTF8String].getName
val decimalType = classOf[Decimal].getName
val stringType: String = classOf[UTF8String].getName
val decimalType: String = classOf[Decimal].getName

private val curId = new java.util.concurrent.atomic.AtomicInteger()

Expand Down Expand Up @@ -108,9 +108,9 @@ class CodeGenContext {
}

/**
* Return the primitive type for a DataType
* Return the Java type for a DataType
*/
def primitiveType(dt: DataType): Term = dt match {
def javaType(dt: DataType): Term = dt match {
case IntegerType => "int"
case LongType => "long"
case ShortType => "short"
Expand Down Expand Up @@ -140,7 +140,7 @@ class CodeGenContext {
case FloatType => "Float"
case BooleanType => "Boolean"
case DateType => "Integer"
case _ => primitiveType(dt)
case _ => javaType(dt)
}

/**
Expand Down Expand Up @@ -189,9 +189,9 @@ class CodeGenContext {
*/
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {

protected val exprType = classOf[Expression].getName
protected val mutableRowType = classOf[MutableRow].getName
protected val genericMutableRowType = classOf[GenericMutableRow].getName
protected val exprType: String = classOf[Expression].getName
protected val mutableRowType: String = classOf[MutableRow].getName
protected val genericMutableRowType: String = classOf[GenericMutableRow].getName

/**
* Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val ctx = newCodeGenContext()
val columns = expressions.zipWithIndex.map {
case (e, i) =>
s"private ${ctx.primitiveType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
}.mkString("\n ")

val initColumns = expressions.zipWithIndex.map {
Expand Down Expand Up @@ -80,7 +80,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
public ${ctx.primitiveType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
if (isNullAt(i)) {
return ${ctx.defaultValue(dataType)};
}
Expand All @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.primitiveType(dataType)} value) {
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
nullBits[i] = false;
switch (i) {
$cases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres

case ByteType | ShortType => // This must go before NumericType
ev.isNull = "false"
ev.primitive = s"(${ctx.primitiveType(dataType)})$value"
ev.primitive = s"(${ctx.javaType(dataType)})$value"
""
case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
ev.isNull = "false"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
if (Double.valueOf(${ev.primitive}).isNaN()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
boolean ${ev.isNull} = true;
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
""" +
children.map { e =>
val eval = e.gen(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
s"""
${condEval.code}
boolean ${ev.isNull} = false;
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${condEval.isNull} && ${condEval.primitive}) {
${trueEval.code}
${ev.isNull} = ${trueEval.isNull};
Expand Down Expand Up @@ -530,7 +530,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
s"""
boolean $got = false;
boolean ${ev.isNull} = true;
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
$cases
$other
"""
Expand Down Expand Up @@ -626,7 +626,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
s"""
boolean $got = false;
boolean ${ev.isNull} = true;
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${keyEval.code}
$cases
$other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
case IntegerType | LongType =>
ev.isNull = "false"
s"""
${ctx.primitiveType(dataType)} ${ev.primitive} = new ${ctx.primitiveType(dataType)}();
${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}();
"""
case _ => super.genCode(ctx, ev)
}
Expand Down Expand Up @@ -109,7 +109,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
case IntegerType | LongType =>
val itemEval = item.gen(ctx)
val setEval = set.gen(ctx)
val htype = ctx.primitiveType(dataType)
val htype = ctx.javaType(dataType)

ev.isNull = "false"
ev.primitive = setEval.primitive
Expand Down Expand Up @@ -160,7 +160,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
case IntegerType | LongType =>
val leftEval = left.gen(ctx)
val rightEval = right.gen(ctx)
val htype = ctx.primitiveType(dataType)
val htype = ctx.javaType(dataType)

ev.isNull = leftEval.isNull
ev.primitive = leftEval.primitive
Expand Down

0 comments on commit 9adaeaf

Please sign in to comment.