[SPARK-6117] [SQL] simplify implementation, add test for DF without numeric columns
azagrebin committed Mar 18, 2015
1 parent 9daf31e commit ddb3950
Showing 3 changed files with 37 additions and 64 deletions.
83 changes: 26 additions & 57 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -751,56 +751,6 @@ class DataFrame private[sql](
     select(colNames :_*)
   }
 
-  /**
-   * Compute specified aggregations for given columns of this [[DataFrame]].
-   * Each row of the resulting [[DataFrame]] contains column with aggregation name
-   * and columns with aggregation results for each given column.
-   * The aggregations are described as a List of mappings of their name to function
-   * which generates aggregation expression from column name.
-   *
-   * Note: can process only simple aggregation expressions
-   * which can be parsed by spark [[SqlParser]]
-   *
-   * {{{
-   *   val aggregations = List(
-   *     "max" -> (col => s"max($col)"), // expression computes max
-   *     "avg" -> (col => s"sum($col)/count($col)")) // expression computes average
-   *   df.multipleAggExpr("summary", aggregations, "age", "height")
-   *
-   *   // summary age   height
-   *   // max     92.0  192.0
-   *   // avg     53.0  178.0
-   * }}}
-   */
-  @scala.annotation.varargs
-  private def multipleAggExpr(
-      aggCol: String,
-      aggregations: List[(String, String => String)],
-      cols: String*): DataFrame = {
-
-    val sqlParser = new SqlParser()
-
-    def addAggNameCol(aggDF: DataFrame, aggName: String = "") =
-      aggDF.selectExpr(s"'$aggName' as $aggCol"::cols.toList:_*)
-
-    def unionWithNextAgg(aggSoFarDF: DataFrame, nextAgg: (String, String => String)) =
-      nextAgg match { case (aggName, colToAggExpr) =>
-        val nextAggDF = if (cols.nonEmpty) {
-          def colToAggCol(col: String) =
-            Column(sqlParser.parseExpression(colToAggExpr(col))).as(col)
-          val aggCols = cols.map(colToAggCol)
-          agg(aggCols.head, aggCols.tail:_*)
-        } else {
-          sqlContext.emptyDataFrame
-        }
-        val nextAggWithNameDF = addAggNameCol(nextAggDF, aggName)
-        aggSoFarDF.unionAll(nextAggWithNameDF)
-      }
-
-    val emptyAgg = addAggNameCol(this).limit(0)
-    aggregations.foldLeft(emptyAgg)(unionWithNextAgg)
-  }
-
   /**
    * Compute numerical statistics for given columns of this [[DataFrame]]:
    * count, mean (avg), stddev (standard deviation), min, max.
@@ -821,14 +771,33 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
+
+    def aggCol(name: String = "") = s"'$name' as summary"
+    val statistics = List[(String, Expression => Expression)](
+      "count" -> (expr => Count(expr)),
+      "mean" -> (expr => Average(expr)),
+      "stddev" -> (expr => Sqrt(Subtract(Average(Multiply(expr, expr)),
+        Multiply(Average(expr), Average(expr))))),
+      "min" -> (expr => Min(expr)),
+      "max" -> (expr => Max(expr)))
+
     val numCols = if (cols.isEmpty) numericColumns.map(_.prettyString) else cols
-    val aggregations = List[(String, String => String)](
-      "count" -> (col => s"count($col)"),
-      "mean" -> (col => s"avg($col)"),
-      "stddev" -> (col => s"sqrt(avg($col*$col) - avg($col)*avg($col))"),
-      "min" -> (col => s"min($col)"),
-      "max" -> (col => s"max($col)"))
-    multipleAggExpr("summary", aggregations, numCols:_*)
+
+    // union all statistics starting from empty one
+    var description = selectExpr(aggCol()::numCols.toList:_*).limit(0)
+    for ((name, colToAgg) <- statistics) {
+      // generate next statistic aggregation
+      val nextAgg = if (numCols.nonEmpty) {
+        val aggCols = numCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+        agg(aggCols.head, aggCols.tail:_*)
+      } else {
+        sqlContext.emptyDataFrame
+      }
+      // add statistic name column
+      val nextStat = nextAgg.selectExpr(aggCol(name)::numCols.toList:_*)
+      description = description.unionAll(nextStat)
+    }
+    description
   }
 
   /**
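The stddev expression introduced above, Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))), is the one-pass shortcut form of the standard deviation:

$$\operatorname{stddev}(x) \;=\; \sqrt{\operatorname{avg}(x^2) - \operatorname{avg}(x)^2} \;=\; \sqrt{E[X^2] - (E[X])^2}$$

Because avg divides by the row count n, this is the population (rather than sample) standard deviation, and it needs only the aggregates avg(x) and avg(x*x), so each statistic reduces to plain aggregate expressions evaluated in one agg call per statistic, as the loop above does.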
14 changes: 9 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -439,18 +439,22 @@ class DataFrameSuite extends QueryTest
   test("describe") {
     def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
 
-    val describeAllCols = describeTestData.describe("age", "height")
+    val describeTwoCols = describeTestData.describe("age", "height")
+    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
+    checkAnswer(describeTwoCols, describeResult)
+
+    val describeAllCols = describeTestData.describe()
     assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
     checkAnswer(describeAllCols, describeResult)
 
-    val describeNoCols = describeTestData.describe()
-    assert(getSchemaAsSeq(describeNoCols) === Seq("summary", "age", "height"))
-    checkAnswer(describeNoCols, describeResult)
-
     val describeOneCol = describeTestData.describe("age")
     assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
     checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
 
+    val describeNoCol = describeTestData.select("name").describe()
+    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )
+
     val emptyDescription = describeTestData.limit(0).describe()
     assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
     checkAnswer(emptyDescription, emptyDescribeResult)
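For orientation, a brief usage sketch of the behaviour the updated test pins down (df stands for any DataFrame with numeric columns "age" and "height", such as the suite's describeTestData; the sketch is illustrative and not part of the commit):

    val twoCols = df.describe("age", "height")   // explicit columns
    twoCols.schema.map(_.name).toSeq             // Seq("summary", "age", "height")
    df.describe()                                // no arguments: all numeric columns
    df.describe("age")                           // one column: Seq("summary", "age")
    df.select("name").describe()                 // no numeric columns: only "summary"
    df.limit(0).describe()                       // empty input: schema kept, count 0, other statistics null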
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -209,8 +209,8 @@ object TestData {
     Row("count", 4.0, 4.0) ::
     Row("mean", 33.0, 178.0) ::
     Row("stddev", 16.583123951777, 10.0) ::
-    Row("min", 16.0, 164) ::
-    Row("max", 60.0, 192) :: Nil
+    Row("min", 16.0, 164.0) ::
+    Row("max", 60.0, 192.0) :: Nil
   val emptyDescribeResult =
     Row("count", 0, 0) ::
     Row("mean", null, null) ::
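The emptyDescribeResult rows above reflect standard SQL aggregate semantics on an empty input: COUNT returns 0 while the remaining aggregates return NULL. A hedged sketch of the same behaviour through the public API (df is any DataFrame with an "age" column; illustrative only):

    import org.apache.spark.sql.functions.{avg, count, max, min}

    val empty = df.limit(0)
    val stats = empty.agg(count(empty("age")), avg(empty("age")), min(empty("age")), max(empty("age")))
    // stats.collect() yields a single Row(0, null, null, null), which is why
    // describe() on an empty DataFrame reports count 0 and null for the other
    // statistics while keeping the column schema.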
