Skip to content

Commit

Permalink
Optimize the Constant Folding by adding more rules
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Apr 28, 2014
1 parent 2645d4f commit 3c045c7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
* - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
* child is foldable.
*/
// TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
def foldable: Boolean = false
def nullable: Boolean
def references: Set[Attribute]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsDown {
// Skip redundant folding of literals.
case l: Literal => l
case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
case e @ In(Literal(v, _), list) if(list.exists(c => c match {
case Literal(candidate, _) if(candidate == v) => true
case _ => false
})) => Literal(true, BooleanType)
case e if e.foldable => Literal(e.eval(null), e.dataType)
}
}
}

/**
* The expression may be constant value, due to one or more of its children expressions is null or
* not null constantly, replaces [[catalyst.expressions.Expression Expressions]] with equivalent
* [[catalyst.expressions.Literal Literal]] values if possible caused by that.
*/
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case l: Literal => l
case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
case e @ IsNull(Literal(null, _)) => Literal(true, BooleanType)
case e @ IsNull(Literal(_, _)) => Literal(false, BooleanType)
case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
Expand All @@ -135,6 +119,11 @@ object NullPropagation extends Rule[LogicalPlan] {
Coalesce(newChildren)
}
}
case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
case e @ In(Literal(v, _), list) if(list.exists(c => c match {
case Literal(candidate, _) if(candidate == v) => true
case _ => false
})) => Literal(true, BooleanType)
// TODO put exceptional cases(Unary & Binary Expression) before here.
case e: UnaryExpression => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
Expand All @@ -143,6 +132,7 @@ object NullPropagation extends Rule[LogicalPlan] {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
}
case e if e.foldable => Literal(e.eval(null), e.dataType)
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
Expand Down Expand Up @@ -213,6 +214,16 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])

@transient
protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)

@transient
protected lazy val isUDFDeterministic = {
val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
(udfType != null && udfType.deterministic())
}

override def foldable = {
isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
}

val dataType: DataType = inspectorToDataType(returnInspector)

Expand Down

0 comments on commit 3c045c7

Please sign in to comment.