Skip to content

Commit

Permalink
add type alias
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 5, 2015
1 parent 3ff25f8 commit e57959d
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

Expand All @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def exprId: ExprId = throw new UnsupportedOperationException

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
s"""
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
if (evaluated == null) null else cast(evaluated)
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = this match {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = this match {

case Cast(child @ BinaryType(), StringType) =>
castOrNull (ctx, ev, c =>
Expand Down Expand Up @@ -465,7 +465,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
s"new org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))")

case other =>
super.genSource(ctx, ev)
super.genCode(ctx, ev)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -62,28 +62,14 @@ abstract class Expression extends TreeNode[Expression] {
val primitiveTerm = ctx.freshName("primitiveTerm")
val objectTerm = ctx.freshName("objectTerm")
val ve = EvaluatedExpression("", nullTerm, primitiveTerm, objectTerm)
ve.code = genSource(ctx, ve)

// Only inject debugging code if debugging is turned on.
// val debugCode =
// if (debugLogging) {
// val localLogger = log
// val localLoggerTree = reify { localLogger }
// s"""
// $localLoggerTree.debug(
// ${this.toString} + ": " + (if (${ev.nullTerm}) "null" else ${ev.primitiveTerm}.toString))
// """
// } else {
// ""
// }

ve.code = genCode(ctx, ve)
ve
}

/**
* Returns Java source code for this expression
*/
def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val e = this.asInstanceOf[Expression]
ctx.references += e
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -117,7 +117,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (left.dataType.isInstanceOf[DecimalType]) {
evaluate(ctx, ev, { case (eval1, eval2) => s"$eval1.$decimalMethod($eval2)" } )
} else {
Expand Down Expand Up @@ -205,7 +205,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
Expand Down Expand Up @@ -263,7 +263,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
Expand Down Expand Up @@ -406,7 +406,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
Expand All @@ -430,7 +430,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
"""
} else {
super.genSource(ctx, ev)
super.genCode(ctx, ev)
}
}
override def toString: String = s"MaxOf($left, $right)"
Expand Down Expand Up @@ -460,7 +460,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (ctx.isNativeType(left.dataType)) {

val eval1 = left.gen(ctx)
Expand All @@ -486,7 +486,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
"""
} else {
super.genSource(ctx, ev)
super.genCode(ctx, ev)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,22 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
* valid if `nullTerm` is set to `true`.
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
*/
case class EvaluatedExpression(var code: String,
nullTerm: String,
primitiveTerm: String,
objectTerm: String)
// Holds the generated Java source for one expression plus the fresh term names
// referenced by that source. `code` is a var because the terms are allocated
// first and the source is filled in afterwards (see Expression.gen, which
// creates the instance with code = "" and then assigns ve.code = genCode(...)).
case class EvaluatedExpression(var code: Code,
nullTerm: Term,       // boolean term: true iff the expression evaluated to null
primitiveTerm: Term,  // term holding the unboxed result; only valid if nullTerm is false
objectTerm: Term)     // term holding a possibly boxed version of the result

/**
* A context for codegen
* @param references the expressions that don't support codegen
A context for codegen, used to track the expressions that are not supported by
codegen; those expressions are evaluated directly instead. Each unsupported
expression is appended at the end of `references`, and its position is recorded
in the generated code so it can be accessed and evaluated there.
*/
case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
class CodeGenContext {

/**
 * Buffer of all expressions that do not support codegen; they are evaluated
 * directly via their position in this buffer. Typed as a mutable ArrayBuffer
 * rather than Seq: callers append to it with `+=` (e.g. `ctx.references += e`
 * in the default genCode), which does not compile against an immutable Seq val.
 */
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()

protected val stringType = classOf[UTF8String].getName
protected val decimalType = classOf[Decimal].getName
Expand All @@ -63,19 +69,19 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
def freshName(prefix: String): String = {
/** Returns a fresh term name: `prefix` followed by the next value of `curId`. */
def freshName(prefix: String): Term = prefix + curId.getAndIncrement

def getColumn(dataType: DataType, ordinal: Int): String = {
/**
 * Generates the Java source that reads column `ordinal` of row `i` as a value
 * of `dataType`: UTF8String columns get an explicit cast, native types use the
 * typed accessor, and everything else is read via apply() with a boxed cast.
 */
def getColumn(dataType: DataType, ordinal: Int): Code = {
  if (dataType == StringType) {
    s"(org.apache.spark.sql.types.UTF8String)i.apply($ordinal)"
  } else if (isNativeType(dataType)) {
    s"i.${accessorForType(dataType)}($ordinal)"
  } else {
    s"(${boxedType(dataType)})i.apply($ordinal)"
  }
}

def setColumn(destinationRow: String, dataType: DataType, ordinal: Int, value: String): String = {
def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = {
dataType match {
case StringType => s"$destinationRow.update($ordinal, $value)"
case dt: DataType if isNativeType(dt) =>
Expand All @@ -84,17 +90,17 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
}
}

def accessorForType(dt: DataType): String = dt match {
/** Name of the row accessor method for values of type `dt` (e.g. "getInt"). */
def accessorForType(dt: DataType): Term =
  if (dt == IntegerType) "getInt" else "get" + boxedType(dt)

def mutatorForType(dt: DataType): String = dt match {
/** Name of the row mutator method for values of type `dt` (e.g. "setInt"). */
def mutatorForType(dt: DataType): Term =
  if (dt == IntegerType) "setInt" else "set" + boxedType(dt)

def hashSetForType(dt: DataType): String = dt match {
def hashSetForType(dt: DataType): Term = dt match {
case IntegerType => classOf[IntegerHashSet].getName
case LongType => classOf[LongHashSet].getName
case unsupportedType =>
Expand All @@ -104,7 +110,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
/**
* Return the primitive type for a DataType
*/
def primitiveType(dt: DataType): String = dt match {
def primitiveType(dt: DataType): Term = dt match {
case IntegerType => "int"
case LongType => "long"
case ShortType => "short"
Expand All @@ -123,7 +129,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
/**
* Return the representation of default value for given DataType
*/
def defaultValue(dt: DataType): String = dt match {
def defaultValue(dt: DataType): Term = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "-1"
Expand All @@ -140,7 +146,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
/**
* Return the boxed type in Java
*/
def boxedType(dt: DataType): String = dt match {
def boxedType(dt: DataType): Term = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
Expand All @@ -159,7 +165,7 @@ case class CodeGenContext(references: mutable.ArrayBuffer[Expression]) {
/**
* Returns a function to generate equal expression in Java
*/
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" }
case dt if isNativeType(dt) => { case (eval1, eval2) => s"$eval1 == $eval2" }
case other => { case (eval1, eval2) => s"$eval1.equals($eval2)" }
Expand Down Expand Up @@ -257,6 +263,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
* expressions that don't support codegen
*/
def newCodeGenContext(): CodeGenContext = {
  // CodeGenContext no longer takes a constructor argument: it allocates its
  // own `references` buffer internally, so the old
  // `new CodeGenContext(new mutable.ArrayBuffer[Expression]())` form is invalid.
  new CodeGenContext
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils
*/
package object codegen {

// A single term (Java identifier) appearing in generated code, e.g. the
// names produced by CodeGenContext.freshName.
type Term = String
// A fragment of generated Java source code, e.g. the result of genCode.
type Code = String

/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
val batches =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, CodeGenContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{EvaluatedExpression, Code, CodeGenContext}
import org.apache.spark.sql.types._

/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
Expand All @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval = child.gen(ctx)
eval.code +s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
Expand All @@ -63,7 +63,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
}
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = ${eval.nullTerm};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, EvaluatedExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, EvaluatedExpression}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -81,7 +81,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres

override def eval(input: Row): Any = value

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
if (value == null) {
s"""
final boolean ${ev.nullTerm} = true;
Expand Down Expand Up @@ -113,7 +113,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
"""
case other =>
super.genSource(ctx, ev)
super.genCode(ctx, ev)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
result
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
s"""
boolean ${ev.nullTerm} = true;
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
Expand Down Expand Up @@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
child.eval(input) == null
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval = child.gen(ctx)
eval.code + s"""
final boolean ${ev.nullTerm} = false;
Expand All @@ -101,7 +101,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
child.eval(input) != null
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.nullTerm} = false;
Expand Down Expand Up @@ -132,7 +132,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
numNonNulls >= n
}

override def genSource(ctx: CodeGenContext, ev: EvaluatedExpression): String = {
override def genCode(ctx: CodeGenContext, ev: EvaluatedExpression): Code = {
val nonnull = ctx.freshName("nonnull")
val code = children.map { e =>
val eval = e.gen(ctx)
Expand Down
Loading

0 comments on commit e57959d

Please sign in to comment.