Skip to content

Commit

Permalink
[SQL] Improve SparkSQL Aggregates
Browse files Browse the repository at this point in the history
* Add native min/max (was using hive before).
* Handle nulls correctly in Avg and Sum.

Author: Michael Armbrust <[email protected]>

Closes apache#683 from marmbrus/aggFixes and squashes the following commits:

64fe30b [Michael Armbrust] Improve SparkSQL Aggregates * Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum.
  • Loading branch information
marmbrus authored and rxin committed May 8, 2014
1 parent 6ed7e2c commit 19c8fb0
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val JOIN = Keyword("JOIN")
protected val LEFT = Keyword("LEFT")
protected val LIMIT = Keyword("LIMIT")
protected val MAX = Keyword("MAX")
protected val MIN = Keyword("MIN")
protected val NOT = Keyword("NOT")
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
Expand Down Expand Up @@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
} |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,67 @@ abstract class AggregateFunction
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}

/**
 * Aggregate expression that yields the smallest non-null value of `child` within a group.
 *
 * Evaluation is delegated to [[MinFunction]]. For partial aggregation the minimum is computed
 * per partition and the final result is the minimum of those partial minimums.
 */
case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  override def references = child.references
  // The evaluator's accumulator starts as null, so aggregating zero rows (or only null inputs)
  // produces null even when `child` itself is declared non-nullable.
  override def nullable = true
  override def dataType = child.dataType
  override def toString = s"MIN($child)"

  override def asPartial: SplitEvaluation = {
    val partialMin = Alias(Min(child), "PartialMin")()
    SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
  }

  override def newInstance() = new MinFunction(child, this)
}

/**
 * Mutable evaluator that tracks the running minimum of `expr` over the rows passed to `update`.
 *
 * Null inputs are ignored; `eval` returns null if no non-null value has been seen.
 */
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  // Smallest non-null value seen so far; null until the first non-null input arrives.
  var currentMin: Any = _

  override def update(input: Row): Unit = {
    // Evaluate the child exactly once per row so that the value compared is the value stored,
    // even if `expr` is expensive or not deterministic.
    val candidate = expr.eval(input)
    if (candidate != null) {
      val isSmaller =
        currentMin == null ||
          GreaterThan(Literal(currentMin, expr.dataType), Literal(candidate, expr.dataType))
            .eval(input) == true
      if (isSmaller) {
        currentMin = candidate
      }
    }
  }

  override def eval(input: Row): Any = currentMin
}

/**
 * Aggregate expression that yields the largest non-null value of `child` within a group.
 *
 * Evaluation is delegated to [[MaxFunction]]. For partial aggregation the maximum is computed
 * per partition and the final result is the maximum of those partial maximums.
 */
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  override def references = child.references
  // The evaluator's accumulator starts as null, so aggregating zero rows (or only null inputs)
  // produces null even when `child` itself is declared non-nullable.
  override def nullable = true
  override def dataType = child.dataType
  override def toString = s"MAX($child)"

  override def asPartial: SplitEvaluation = {
    val partialMax = Alias(Max(child), "PartialMax")()
    SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
  }

  override def newInstance() = new MaxFunction(child, this)
}

/**
 * Mutable evaluator that tracks the running maximum of `expr` over the rows passed to `update`.
 *
 * Null inputs are ignored; `eval` returns null if no non-null value has been seen.
 */
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  // Largest non-null value seen so far; null until the first non-null input arrives.
  var currentMax: Any = _

  override def update(input: Row): Unit = {
    // Evaluate the child exactly once per row so that the value compared is the value stored,
    // even if `expr` is expensive or not deterministic.
    val candidate = expr.eval(input)
    if (candidate != null) {
      val isLarger =
        currentMax == null ||
          LessThan(Literal(currentMax, expr.dataType), Literal(candidate, expr.dataType))
            .eval(input) == true
      if (isLarger) {
        currentMax = candidate
      }
    }
  }

  override def eval(input: Row): Any = currentMax
}


case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
Expand All @@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
}

override def newInstance()= new CountFunction(child, this)
override def newInstance() = new CountFunction(child, this)
}

case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
Expand All @@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
override def nullable = false
override def dataType = IntegerType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
override def newInstance()= new CountDistinctFunction(expressions, this)
override def newInstance() = new CountDistinctFunction(expressions, this)
}

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
Expand All @@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
}

override def newInstance()= new AverageFunction(child, this)
override def newInstance() = new AverageFunction(child, this)
}

case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
Expand All @@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
partialSum :: Nil)
}

override def newInstance()= new SumFunction(child, this)
override def newInstance() = new SumFunction(child, this)
}

case class SumDistinct(child: Expression)
Expand All @@ -153,7 +214,7 @@ case class SumDistinct(child: Expression)
override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"

override def newInstance()= new SumDistinctFunction(child, this)
override def newInstance() = new SumDistinctFunction(child, this)
}

case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
Expand All @@ -168,19 +229,21 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
override def newInstance()= new FirstFunction(child, this)
override def newInstance() = new FirstFunction(child, this)
}

case class AverageFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {

def this() = this(null, null) // Required for serialization.

private val zero = Cast(Literal(0), expr.dataType)

private var count: Long = _
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
private val sum = MutableLiteral(zero.eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)

private val addFunction = Add(sum, expr)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))

override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
Expand Down Expand Up @@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.

private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
private val zero = Cast(Literal(0), expr.dataType)

private val sum = MutableLiteral(zero.eval(null))

private val addFunction = Add(sum, expr)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))

override def update(input: Row): Unit = {
sum.update(addFunction, input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest {
Seq((1,3),(2,3),(3,3)))
}

// Runs every native aggregate over the `nullInts` table in one query and checks the single
// result row: MIN, MAX, AVG, SUM and COUNT respectively.
test("aggregates with nulls") {
  val aggregated = sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts")
  val expected = (1, 3, 2, 6, 3) :: Nil
  checkAnswer(aggregated, expected)
}

test("select *") {
checkAnswer(
sql("SELECT * FROM testData"),
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,14 @@ object TestData {
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
nullableRepeatedData.registerAsTable("nullableRepeatedData")

// Single-column record whose field is a reference-typed Integer so that it can hold null.
case class NullInts(a: Integer)
// Four rows: the values 1, 2 and 3 plus one row whose `a` is null.
val nullInts =
TestSQLContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
)
// Registers the data under the table name "nullInts".
nullInts.registerAsTable("nullInts")
}

0 comments on commit 19c8fb0

Please sign in to comment.