diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 7c2b8a9407884..2c7c58e66b855 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst
 
 private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
-
-  protected[spark] def specializeSingleDistinctAggPlanning: Boolean
 }
 
 /**
@@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf {
   override def caseSensitiveAnalysis: Boolean = {
     throw new UnsupportedOperationException
   }
-
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
-    throw new UnsupportedOperationException
-  }
 }
 
 /** A CatalystConf that can be used for local testing. */
 case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 9c78f6d4cc71b..4e7d1341028ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       .filter(_.isDistinct)
       .groupBy(_.aggregateFunction.children.toSet)
 
-    val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
-      // When the flag is set to specialize single distinct agg planning,
-      // we will rely on our Aggregation strategy to handle queries with a single
-      // distinct column.
-      distinctAggGroups.size > 1
-    } else {
-      distinctAggGroups.size >= 1
-    }
-    if (shouldRewrite) {
+    // Aggregation strategy can handle the query with single distinct
+    if (distinctAggGroups.size > 1) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = new AttributeReference("gid", IntegerType, false)()
       val groupByMap = a.groupingExpressions.collect {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 58adf64e49869..3d819262859f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -449,18 +449,6 @@ private[spark] object SQLConf {
     doc = "When true, we could use `datasource`.`path` as table in SQL query"
   )
 
-  val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING =
-    booleanConf("spark.sql.specializeSingleDistinctAggPlanning",
-      defaultValue = Some(false),
-      isPublic = false,
-      doc = "When true, if a query only has a single distinct column and it has " +
-        "grouping expressions, we will use our planner rule to handle this distinct " +
-        "column (other cases are handled by DistinctAggregationRewriter). " +
-        "When false, we will always use DistinctAggregationRewriter to plan " +
-        "aggregation queries with DISTINCT keyword. This is an internal flag that is " +
-        "used to benchmark the performance impact of using DistinctAggregationRewriter to " +
-        "plan aggregation queries with a single distinct column.")
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
     val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean =
-    getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING)
-
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 008478a6a0e17..0c74df0aa5fdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -17,15 +17,15 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
- * The base class of [[SortBasedAggregationIterator]].
+ * The base class of [[SortBasedAggregationIterator]] and [[TungstenAggregationIterator]].
  * It mainly contains two parts:
  * 1. It initializes aggregate functions.
  * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
@@ -33,64 +33,58 @@ import scala.collection.mutable.ArrayBuffer
  * is used to generate result.
  */
 abstract class AggregationIterator(
-    groupingKeyAttributes: Seq[Attribute],
-    valueAttributes: Seq[Attribute],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
+    inputAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean)
-  extends Iterator[InternalRow] with Logging {
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection))
+  extends Iterator[UnsafeRow] with Logging {
 
   ///////////////////////////////////////////////////////////////////////////
   // Initializing functions.
   ///////////////////////////////////////////////////////////////////////////
 
-  // An Seq of all AggregateExpressions.
-  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
-  // are at the beginning of the allAggregateExpressions.
-  protected val allAggregateExpressions =
-    nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
-  require(
-    allAggregateExpressions.map(_.mode).distinct.length <= 2,
-    s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.")
-
   /**
-   * The distinct modes of AggregateExpressions. Right now, we can handle the following mode:
-   *  - Partial-only: all AggregateExpressions have the mode of Partial;
-   *  - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge);
-   *  - Final-only: all AggregateExpressions have the mode of Final;
-   *  - Final-Complete: some AggregateExpressions have the mode of Final and
-   *    others have the mode of Complete;
-   *  - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions
-   *    with mode Complete in completeAggregateExpressions; and
-   *  - Grouping-only: there is no AggregateExpression.
-   */
-  protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) =
-    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
-      completeAggregateExpressions.map(_.mode).distinct.headOption
+   * The following combinations of AggregationMode are supported:
+   * - Partial
+   * - PartialMerge (for single distinct)
+   * - Partial and PartialMerge (for single distinct)
+   * - Final
+   * - Complete (for SortBasedAggregate with functions that do not support Partial)
+   * - Final and Complete (currently not used)
+   *
+   * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression
+   * could have a flag to tell it's final or not.
+   */
+  {
+    val modes = aggregateExpressions.map(_.mode).distinct.toSet
+    require(modes.size <= 2,
+      s"$aggregateExpressions are not supported because they have more than 2 distinct modes.")
+    require(modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)),
+      s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete at the same time.")
+  }
 
   // Initialize all AggregateFunctions by binding references if necessary,
   // and set inputBufferOffset and mutableBufferOffset.
-  protected val allAggregateFunctions: Array[AggregateFunction] = {
+  protected def initializeAggregateFunctions(
+      expressions: Seq[AggregateExpression],
+      startingInputBufferOffset: Int): Array[AggregateFunction] = {
     var mutableBufferOffset = 0
-    var inputBufferOffset: Int = initialInputBufferOffset
-    val functions = new Array[AggregateFunction](allAggregateExpressions.length)
+    var inputBufferOffset: Int = startingInputBufferOffset
+    val functions = new Array[AggregateFunction](expressions.length)
     var i = 0
-    while (i < allAggregateExpressions.length) {
-      val func = allAggregateExpressions(i).aggregateFunction
-      val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match {
+    while (i < expressions.length) {
+      val func = expressions(i).aggregateFunction
+      val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
         case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
           // We need to create BoundReferences if the function is not an
           // expression-based aggregate function (it does not support code-gen) and the mode of
           // this function is Partial or Complete because we will call eval of this
           // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, valueAttributes)
+          BindReferences.bindReference(func, inputAttributes)
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
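// ---------------------------------------------------------------------------
// Editor's note: the two `require` calls added above encode the new invariant
// that map-side modes (Partial/PartialMerge) and reduce-side modes
// (Final/Complete) never mix within one operator. A minimal, self-contained
// sketch of that rule in plain Scala -- all names below are hypothetical, not
// part of Spark's API:
object ModeInvariantSketch {
  sealed trait AggregateMode
  case object Partial extends AggregateMode
  case object PartialMerge extends AggregateMode
  case object Final extends AggregateMode
  case object Complete extends AggregateMode

  def validate(modes: Set[AggregateMode]): Unit = {
    require(modes.size <= 2, s"more than 2 distinct modes: $modes")
    require(
      modes.subsetOf(Set(Partial, PartialMerge)) || modes.subsetOf(Set(Final, Complete)),
      s"cannot mix map-side and reduce-side modes: $modes")
  }

  def main(args: Array[String]): Unit = {
    validate(Set(Partial, PartialMerge)) // OK: map-side aggregation (single distinct)
    validate(Set(Final, Complete))       // OK: reduce-side aggregation
    // validate(Set(Partial, Final))     // would throw IllegalArgumentException
  }
}
// ---------------------------------------------------------------------------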
@@ -117,15 +111,18 @@ abstract class AggregationIterator( functions } + protected val aggregateFunctions: Array[AggregateFunction] = + initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset) + // Positions of those imperative aggregate functions in allAggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are imperative aggregate functions. // ImperativeAggregateFunctionPositions will be [1, 2]. - private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + protected[this] val allImperativeAggregateFunctionPositions: Array[Int] = { val positions = new ArrayBuffer[Int]() var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { case agg: DeclarativeAggregate => case _ => positions += i } @@ -134,17 +131,9 @@ abstract class AggregationIterator( positions.toArray } - // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - - // All imperative aggregate functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - // The projection used to initialize buffer values for all expression-based aggregates. - private[this] val expressionAggInitialProjection = { - val initExpressions = allAggregateFunctions.flatMap { + protected[this] val expressionAggInitialProjection = { + val initExpressions = aggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.initialValues // For the positions corresponding to imperative aggregate functions, we'll use special // no-op expressions which are ignored during projection code-generation. @@ -154,248 +143,112 @@ abstract class AggregationIterator( } // All imperative AggregateFunctions. - private[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + protected[this] val allImperativeAggregateFunctions: Array[ImperativeAggregate] = allImperativeAggregateFunctionPositions - .map(allAggregateFunctions) + .map(aggregateFunctions) .map(_.asInstanceOf[ImperativeAggregate]) - /////////////////////////////////////////////////////////////////////////// - // Methods and fields used by sub-classes. - /////////////////////////////////////////////////////////////////////////// - // Initializing functions used to process a row. - protected val processRow: (MutableRow, InternalRow) => Unit = { - val rowToBeProcessed = new JoinedRow - val aggregationBufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. - expressionAggUpdateProjection(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. 
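// ---------------------------------------------------------------------------
// Editor's note: initializeAggregateFunctions (above) walks the expressions
// once and gives each function the offset at which its buffer slots start. A
// simplified, runnable sketch of that bookkeeping (hypothetical names; the
// real code advances inputBufferOffset only for PartialMerge/Final functions
// and also rebinds imperative functions):
object BufferOffsetSketch {
  // (function name, number of aggregation-buffer slots it occupies)
  final case class FakeAggregate(name: String, bufferWidth: Int)

  def assignOffsets(
      functions: Seq[FakeAggregate],
      startingInputBufferOffset: Int): Seq[(String, Int, Int)] = {
    var mutableBufferOffset = 0
    var inputBufferOffset = startingInputBufferOffset
    functions.map { f =>
      val assigned = (f.name, mutableBufferOffset, inputBufferOffset)
      mutableBufferOffset += f.bufferWidth
      inputBufferOffset += f.bufferWidth
      assigned
    }
  }

  def main(args: Array[String]): Unit = {
    // avg keeps (sum, count) -> two slots; count keeps one slot. With one
    // grouping key in front of the input buffer, input offsets start at 1.
    assignOffsets(Seq(FakeAggregate("avg", 2), FakeAggregate("count", 1)), 1)
      .foreach(println) // (avg,0,1) then (count,2,3)
  }
}
// ---------------------------------------------------------------------------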
- var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) { - // If initialInputBufferOffset, the input value does not contain - // grouping keys. - // This part is pretty hacky. - allAggregateFunctions.flatMap(_.inputAggBufferAttributes).toSeq - } else { - groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.inputAggBufferAttributes) - } - // val inputAggregationBufferSchema = - // groupingKeyAttributes ++ - // allAggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection( - mergeExpressions, - aggregationBufferSchema ++ inputAggregationBufferSchema)() - - (currentBuffer: MutableRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - // The first initialInputBufferOffset values of the input aggregation buffer is - // for grouping expressions and distinct columns. - val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset) - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - // We do not touch buffer values of aggregate functions with the Final mode. 
- val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - - val mergeInputSchema = - aggregationBufferSchema ++ - groupingAttributesAndDistinctColumns ++ - nonCompleteAggregateFunctions.flatMap(_.inputAggBufferAttributes) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalExpressionAggMergeProjection = - newMutableProjection(mergeExpressions, mergeInputSchema)() - - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffers. - finalExpressionAggMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 + protected def generateProcessRow( + expressions: Seq[AggregateExpression], + functions: Seq[AggregateFunction], + inputAttributes: Seq[Attribute]): (MutableRow, InternalRow) => Unit = { + val joinedRow = new JoinedRow + if (expressions.nonEmpty) { + val mergeExpressions = functions.zipWithIndex.flatMap { + case (ae: DeclarativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => ae.updateExpressions + case PartialMerge | Final => ae.mergeExpressions } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = - completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() - - (currentBuffer: MutableRow, row: InternalRow) => { - val input = rowToBeProcessed(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. 
- completeExpressionAggUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 + case (agg: AggregateFunction, _) => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val updateFunctions = functions.zipWithIndex.collect { + case (ae: ImperativeAggregate, i) => + expressions(i).mode match { + case Partial | Complete => + (buffer: MutableRow, row: InternalRow) => ae.update(buffer, row) + case PartialMerge | Final => + (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } + } + // This projection is used to merge buffer values for all expression-based aggregates. + val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) + val updateProjection = + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + + (currentBuffer: MutableRow, row: InternalRow) => { + // Process all expression-based aggregate functions. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < updateFunctions.length) { + updateFunctions(i)(currentBuffer, row) + i += 1 } - + } + } else { // Grouping only. - case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {} - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + (currentBuffer: MutableRow, row: InternalRow) => {} } } - // Initializing the function used to generate the output row. - protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { - val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) - val mutableOutput = if (outputsUnsafeRows) { - UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) - } else { - safeOutputRow - } - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not - // support generic getter), we create a mutable projection to output the - // JoinedRow(currentGroupingKey, currentBuffer) - val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.aggBufferAttributes) - val resultProjection = - newMutableProjection( - groupingKeyAttributes ++ bufferSchema, - groupingKeyAttributes ++ bufferSchema)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer)) - // rowToBeEvaluated(currentGroupingKey, currentBuffer) - } + protected val processRow: (MutableRow, InternalRow) => Unit = + generateProcessRow(aggregateExpressions, aggregateFunctions, inputAttributes) - // Final-only, Complete-only and Final-Complete: every output row contains values representing - // resultExpressions. 
- case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val bufferSchemata = - allAggregateFunctions.flatMap(_.aggBufferAttributes) - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - // TODO: Use unsafe row. - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - newMutableProjection( - resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() - resultProjection.target(mutableOutput) + protected val groupingProjection: UnsafeProjection = + UnsafeProjection.create(groupingExpressions, inputAttributes) + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult)) + // Initializing the function used to generate the output row. + protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val joinedRow = new JoinedRow + val modes = aggregateExpressions.map(_.mode).distinct + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + if (modes.contains(Final) || modes.contains(Complete)) { + val evalExpressions = aggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction => NoOp + } + val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + expressionAggEvalProjection.target(aggregateResult) + + val resultProjection = + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 } - + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + val resultProjection = UnsafeProjection.create( + groupingAttributes ++ bufferAttributes, + groupingAttributes ++ bufferAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + } + } else { // Grouping-only: we only output values of grouping expressions. 
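// ---------------------------------------------------------------------------
// Editor's note: the unified generateProcessRow above replaces the old
// per-combination branches with one rule -- Partial/Complete expressions fold
// a raw input row into the buffer (update), while PartialMerge/Final fold
// another buffer into it (merge). A runnable toy version for an avg-like
// aggregate over a (sum, count) buffer (hypothetical, not Spark's API):
object ProcessRowSketch {
  sealed trait Mode
  case object Partial extends Mode
  case object PartialMerge extends Mode

  def processRowFor(mode: Mode): (Array[Long], Array[Long]) => Unit = mode match {
    case Partial      => (buf, row) => { buf(0) += row(0); buf(1) += 1 }            // update
    case PartialMerge => (buf, other) => { buf(0) += other(0); buf(1) += other(1) } // merge
  }

  def main(args: Array[String]): Unit = {
    val left = Array(0L, 0L)
    val right = Array(0L, 0L)
    val update = processRowFor(Partial)
    Seq(1L, 2L).foreach(v => update(left, Array(v)))
    Seq(3L, 4L).foreach(v => update(right, Array(v)))
    processRowFor(PartialMerge)(left, right)
    println(left.toList) // List(10, 4): sum = 10, count = 4
  }
}
// ---------------------------------------------------------------------------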
- case (None, None) => - val resultProjection = - newMutableProjection(resultExpressions, groupingKeyAttributes)() - resultProjection.target(mutableOutput) - - (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { - resultProjection(currentGroupingKey) - } - - case other => - sys.error( - s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " + - s"support evaluate modes $other in this iterator.") + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + resultProjection(currentGroupingKey) + } } } + protected val generateOutput: (UnsafeRow, MutableRow) => UnsafeRow = + generateResultProjection() + /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(buffer: MutableRow): Unit = { expressionAggInitialProjection.target(buffer)(EmptyRow) @@ -405,10 +258,4 @@ abstract class AggregationIterator( i += 1 } } - - /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ - protected def newBuffer: MutableRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index ee982453c3287..c5470a6989de7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) @@ -42,10 +40,8 @@ case class SortBasedAggregate( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = false - + override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -76,31 +72,24 @@ case class SortBasedAggregate( if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. 
- Iterator[InternalRow]() + Iterator[UnsafeRow]() } else { - val groupingKeyProjection = - UnsafeProjection.create(groupingExpressions, child.output) - val outputIter = new SortBasedAggregationIterator( - groupingKeyProjection, - groupingExpressions.map(_.toAttribute), + groupingExpressions, child.output, iter, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, - outputsUnsafeRows, numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { // There is no input and there is no grouping expressions. // We need to output a single row as the output. numOutputRows += 1 - Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) + Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput()) } else { outputIter } @@ -109,7 +98,7 @@ case class SortBasedAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions val keyString = groupingExpressions.mkString("[", ",", "]") val functionString = allAggregateExpressions.mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index fe5c3195f867b..ac920aa8bc7f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric /** * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been - * sorted by values of [[groupingKeyAttributes]]. + * sorted by values of [[groupingExpressions]]. */ class SortBasedAggregationIterator( - groupingKeyProjection: InternalRow => InternalRow, - groupingKeyAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - outputsUnsafeRows: Boolean, numInputRows: LongSQLMetric, numOutputRows: LongSQLMetric) extends AggregationIterator( - groupingKeyAttributes, + groupingExpressions, valueAttributes, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, - outputsUnsafeRows) { - - override protected def newBuffer: MutableRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + newMutableProjection) { + + /** + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. 
+ */ + private def newBuffer: MutableRow = { + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) @@ -76,10 +73,10 @@ class SortBasedAggregationIterator( /////////////////////////////////////////////////////////////////////////// // The partition key of the current partition. - private[this] var currentGroupingKey: InternalRow = _ + private[this] var currentGroupingKey: UnsafeRow = _ // The partition key of next partition. - private[this] var nextGroupingKey: InternalRow = _ + private[this] var nextGroupingKey: UnsafeRow = _ // The first row of next partition. private[this] var firstRowInNextGroup: InternalRow = _ @@ -94,7 +91,7 @@ class SortBasedAggregationIterator( if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) val inputRow = inputIterator.next() - nextGroupingKey = groupingKeyProjection(inputRow).copy() + nextGroupingKey = groupingProjection(inputRow).copy() firstRowInNextGroup = inputRow.copy() numInputRows += 1 sortedInputHasNewGroup = true @@ -120,7 +117,7 @@ class SortBasedAggregationIterator( while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. val currentRow = inputIterator.next() - val groupingKey = groupingKeyProjection(currentRow) + val groupingKey = groupingProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row. @@ -146,7 +143,7 @@ class SortBasedAggregationIterator( override final def hasNext: Boolean = sortedInputHasNewGroup - override final def next(): InternalRow = { + override final def next(): UnsafeRow = { if (hasNext) { // Process the current group. processCurrentSortedGroup() @@ -162,8 +159,8 @@ class SortBasedAggregationIterator( } } - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { initializeBuffer(sortBasedAggregationBuffer) - generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) + generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 920de615e1d86..b8849c827048a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { private[this] val aggregateBufferAttributes = { - (nonCompleteAggregateExpressions ++ completeAggregateExpressions) - .flatMap(_.aggregateFunction.aggBufferAttributes) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } - require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes)) + 
require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), @@ -53,9 +50,7 @@ case class TungstenAggregate( "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,10 +89,8 @@ case class TungstenAggregate( val aggregationIterator = new TungstenAggregationIterator( groupingExpressions, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, + aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, newMutableProjection, @@ -119,7 +112,7 @@ case class TungstenAggregate( } override def simpleString: String = { - val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions + val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => @@ -135,9 +128,7 @@ case class TungstenAggregate( } object TungstenAggregate { - def supportsAggregate( - groupingExpressions: Seq[Expression], - aggregateBufferAttributes: Seq[Attribute]): Boolean = { + def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 04391443920ac..582fdbe547061 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. @@ -63,15 +61,11 @@ import org.apache.spark.sql.types.StructType * * @param groupingExpressions * expressions for grouping keys - * @param nonCompleteAggregateExpressions + * @param aggregateExpressions * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], * [[PartialMerge]], or [[Final]]. 
- * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * @param aggregateAttributes the attributes of the aggregateExpressions' * outputs when they are stored in the final aggregation buffer. - * @param completeAggregateExpressions - * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. - * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs - * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -83,10 +77,8 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression], - completeAggregateAttributes: Seq[Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), @@ -97,378 +89,62 @@ class TungstenAggregationIterator( numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) - extends Iterator[UnsafeRow] with Logging { + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { /////////////////////////////////////////////////////////////////////////// // Part 1: Initializing aggregate functions. /////////////////////////////////////////////////////////////////////////// - // A Seq containing all AggregateExpressions. - // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final - // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression] = - nonCompleteAggregateExpressions ++ completeAggregateExpressions - - // Check to make sure we do not have more than three modes in our AggregateExpressions. - // If we have, users are hitting a bug and we throw an IllegalStateException. - if (allAggregateExpressions.map(_.mode).distinct.length > 2) { - throw new IllegalStateException( - s"$allAggregateExpressions should have no more than 2 kinds of modes.") - } - // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled - // - // The modes of AggregateExpressions. Right now, we can handle the following mode: - // - Partial-only: - // All AggregateExpressions have the mode of Partial. - // For this case, aggregationMode is (Some(Partial), None). - // - PartialMerge-only: - // All AggregateExpressions have the mode of PartialMerge). - // For this case, aggregationMode is (Some(PartialMerge), None). - // - Final-only: - // All AggregateExpressions have the mode of Final. - // For this case, aggregationMode is (Some(Final), None). - // - Final-Complete: - // Some AggregateExpressions have the mode of Final and - // others have the mode of Complete. For this case, - // aggregationMode is (Some(Final), Some(Complete)). 
- // - Complete-only: - // nonCompleteAggregateExpressions is empty and we have AggregateExpressions - // with mode Complete in completeAggregateExpressions. For this case, - // aggregationMode is (None, Some(Complete)). - // - Grouping-only: - // There is no AggregateExpression. For this case, AggregationMode is (None,None). - // - private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = { - nonCompleteAggregateExpressions.map(_.mode).distinct.headOption -> - completeAggregateExpressions.map(_.mode).distinct.headOption - } - - // Initialize all AggregateFunctions by binding references, if necessary, - // and setting inputBufferOffset and mutableBufferOffset. - private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction] = { - var mutableBufferOffset = 0 - var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](allAggregateExpressions.length) - var i = 0 - while (i < allAggregateExpressions.length) { - val func = allAggregateExpressions(i).aggregateFunction - val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length - // We need to use this mode instead of func.mode in order to handle aggregation mode switching - // when switching to sort-based aggregation: - val mode = if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 - val funcWithBoundReferences = mode match { - case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => - // We need to create BoundReferences if the function is not an - // expression-based aggregate function (it does not support code-gen) and the mode of - // this function is Partial or Complete because we will call eval of this - // function's children in the update method of this aggregate function. - // Those eval calls require BoundReferences to work. - BindReferences.bindReference(func, originalInputAttributes) - case _ => - // We only need to set inputBufferOffset for aggregate functions with mode - // PartialMerge and Final. - val updatedFunc = func match { - case function: ImperativeAggregate => - function.withNewInputAggBufferOffset(inputBufferOffset) - case function => function - } - inputBufferOffset += func.aggBufferSchema.length - updatedFunc - } - val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match { - case function: ImperativeAggregate => - // Set mutableBufferOffset for this function. It is important that setting - // mutableBufferOffset happens after all potential bindReference operations - // because bindReference will create a new instance of the function. - function.withNewMutableAggBufferOffset(mutableBufferOffset) - case function => function - } - mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length - functions(i) = funcWithUpdatedAggBufferOffset - i += 1 - } - functions - } - - private[this] var allAggregateFunctions: Array[AggregateFunction] = - initializeAllAggregateFunctions(initialInputBufferOffset) - - // Positions of those imperative aggregate functions in allAggregateFunctions. - // For example, say that we have func1, func2, func3, func4 in aggregateFunctions, and - // func2 and func3 are imperative aggregate functions. Then - // allImperativeAggregateFunctionPositions will be [1, 2]. Note that this does not need to be - // updated when falling back to sort-based aggregation because the positions of the aggregate - // functions do not change in that case. 
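// ---------------------------------------------------------------------------
// Editor's note: both iterators precompute the positions of imperative
// functions once, so the per-row loops touch only those indices. The nearby
// comment's example ("func2 and func3 imperative" -> positions [1, 2]) in a
// tiny runnable form (hypothetical types):
object ImperativePositionsSketch {
  sealed trait Func
  case object Declarative extends Func
  case object Imperative extends Func

  def imperativePositions(funcs: Seq[Func]): Array[Int] =
    funcs.zipWithIndex.collect { case (Imperative, i) => i }.toArray

  def main(args: Array[String]): Unit = {
    val funcs = Seq(Declarative, Imperative, Imperative, Declarative)
    println(imperativePositions(funcs).toList) // List(1, 2)
  }
}
// ---------------------------------------------------------------------------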
- private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { - val positions = new ArrayBuffer[Int]() - var i = 0 - while (i < allAggregateFunctions.length) { - allAggregateFunctions(i) match { - case agg: DeclarativeAggregate => - case _ => positions += i - } - i += 1 - } - positions.toArray - } - /////////////////////////////////////////////////////////////////////////// // Part 2: Methods and fields used by setting aggregation buffer values, // processing input rows from inputIter, and generating output // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values for all expression-based aggregates. - // Note that this projection does not need to be updated when switching to sort-based aggregation - // because the schema of empty aggregation buffers does not change in that case. - private[this] val expressionAggInitialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.initialValues - // For the positions corresponding to imperative aggregate functions, we'll use special - // no-op expressions which are ignored during projection code-generation. - case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) - } - newMutableProjection(initExpressions, Nil)() - } - // Creates a new aggregation buffer and initializes buffer values. - // This function should be only called at most three times (when we create the hash map, - // when we switch to sort-based aggregation, and when we create the re-used buffer for - // sort-based aggregation). + // This function should be only called at most two times (when we create the hash map, + // and when we create the re-used buffer for sort-based aggregation). private def createNewAggregationBuffer(): UnsafeRow = { - val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes) + val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val buffer: UnsafeRow = UnsafeProjection.create(bufferSchema.map(_.dataType)) .apply(new GenericMutableRow(bufferSchema.length)) // Initialize declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initialize imperative aggregates' buffer values - allAggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) buffer } - // Creates a function used to process a row based on the given inputAttributes. - private def generateProcessRow( - inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = { - - val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - val joinedRow = new JoinedRow() - - aggregationMode match { - // Partial-only - case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - val expressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - expressionAggUpdateProjection.target(currentBuffer) - // Process all expression-based aggregate functions. 
- expressionAggUpdateProjection(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // PartialMerge-only or Final-only - case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val imperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - // This projection is used to merge buffer values for all expression-based aggregates. - val expressionAggMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // Process all expression-based aggregate functions. - expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - // Process all imperative aggregate functions. - var i = 0 - while (i < imperativeAggregateFunctions.length) { - imperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Final-Complete - case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = - nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val completeOffsetExpressions = - Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val mergeExpressions = - nonCompleteAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } ++ completeOffsetExpressions - val finalMergeProjection = - newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() - - // We do not touch buffer values of aggregate functions with the Final mode. - val finalOffsetExpressions = - Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update buffers. - completeUpdateProjection.target(currentBuffer)(input) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - - // For all aggregate functions with mode Final, merge buffer values in row to - // currentBuffer. 
- finalMergeProjection.target(currentBuffer)(input) - i = 0 - while (i < nonCompleteImperativeAggregateFunctions.length) { - nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) - i += 1 - } - } - - // Complete-only - case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction] = - allAggregateFunctions.takeRight(completeAggregateExpressions.length) - // All imperative aggregate functions with mode Complete. - val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = - completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - - val updateExpressions = completeAggregateFunctions.flatMap { - case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) - } - val completeExpressionAggUpdateProjection = - newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() - - (currentBuffer: UnsafeRow, row: InternalRow) => { - // For all aggregate functions with mode Complete, update buffers. - completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) - var i = 0 - while (i < completeImperativeAggregateFunctions.length) { - completeImperativeAggregateFunctions(i).update(currentBuffer, row) - i += 1 - } - } - - // Grouping only. - case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {} - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") - } - } - // Creates a function used to generate output rows. - private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = { - - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val bufferAttributes = allAggregateFunctions.flatMap(_.aggBufferAttributes) - - aggregationMode match { - // Partial-only or PartialMerge-only: every output row is basically the values of - // the grouping expressions and the corresponding aggregation buffer. - case (Some(Partial), None) | (Some(PartialMerge), None) => - val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - val bufferSchema = StructType.fromAttributes(bufferAttributes) - val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - unsafeRowJoiner.join(currentGroupingKey, currentBuffer) - } - - // Final-only, Complete-only and Final-Complete: a output row is generated based on - // resultExpressions. 
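// ---------------------------------------------------------------------------
// Editor's note: result generation picks one of three output shapes -- final
// phases evaluate the buffer, partial phases forward key ++ buffer to the
// next stage, and grouping-only queries emit just the key. A loose, untyped
// analogy of generateResultProjection's three branches (a sum stands in for
// the real evaluate expressions; everything here is hypothetical):
object ResultProjectionSketch {
  sealed trait Phase
  case object FinalPhase extends Phase
  case object PartialPhase extends Phase
  case object GroupingOnly extends Phase

  def generateOutput(phase: Phase): (Seq[Any], Seq[Long]) => Seq[Any] = phase match {
    case FinalPhase   => (key, buffer) => key :+ buffer.sum // "evaluate" the buffer
    case PartialPhase => (key, buffer) => key ++ buffer     // ship raw buffer downstream
    case GroupingOnly => (key, _)      => key
  }

  def main(args: Array[String]): Unit = {
    println(generateOutput(FinalPhase)(Seq("a"), Seq(1L, 2L)))   // List(a, 3)
    println(generateOutput(PartialPhase)(Seq("a"), Seq(1L, 2L))) // List(a, 1, 2)
  }
}
// ---------------------------------------------------------------------------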
- case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => - val joinedRow = new JoinedRow() - val evalExpressions = allAggregateFunctions.map { - case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction => NoOp - } - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() - // These are the attributes of the row produced by `expressionAggEvalProjection` - val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes - val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) - expressionAggEvalProjection.target(aggregateResult) - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) - - val allImperativeAggregateFunctions: Array[ImperativeAggregate] = - allAggregateFunctions.collect { case func: ImperativeAggregate => func} - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection(currentBuffer) - // Generate results for all imperative aggregate functions. - var i = 0 - while (i < allImperativeAggregateFunctions.length) { - aggregateResult.update( - allImperativeAggregateFunctionPositions(i), - allImperativeAggregateFunctions(i).eval(currentBuffer)) - i += 1 - } - resultProjection(joinedRow(currentGroupingKey, aggregateResult)) - } - - // Grouping-only: a output row is generated from values of grouping expressions. - case (None, None) => - val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) - - (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(currentGroupingKey) - } - - case other => - throw new IllegalStateException( - s"${aggregationMode} should not be passed into TungstenAggregationIterator.") + override protected def generateResultProjection(): (UnsafeRow, MutableRow) => UnsafeRow = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.nonEmpty && !modes.contains(Final) && !modes.contains(Complete)) { + // Fast path for partial aggregation, UnsafeRowJoiner is usually faster than projection + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val bufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes) + val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + val bufferSchema = StructType.fromAttributes(bufferAttributes) + val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + + (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => { + unsafeRowJoiner.join(currentGroupingKey, currentBuffer.asInstanceOf[UnsafeRow]) + } + } else { + super.generateResultProjection() } } - // An UnsafeProjection used to extract grouping keys from the input rows. - private[this] val groupProjection = - UnsafeProjection.create(groupingExpressions, originalInputAttributes) - - // A function used to process a input row. Its first argument is the aggregation buffer - // and the second argument is the input row. - private[this] var processRow: (UnsafeRow, InternalRow) => Unit = - generateProcessRow(originalInputAttributes) - - // A function used to generate output rows based on the grouping keys (first argument) - // and the corresponding aggregation buffer (second argument). - private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow = - generateResultProjection() - // An aggregation buffer containing initial buffer values. 
   // An aggregation buffer containing initial buffer values. It is used to
   // initialize other aggregation buffers.
   private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()

@@ -482,7 +158,7 @@
   // all groups and their corresponding aggregation buffers for hash-based aggregation.
   private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
     initialAggregationBuffer,
-    StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
+    StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
     TaskContext.get().taskMemoryManager(),
     1024 * 16, // initial capacity
@@ -499,7 +175,7 @@
     if (groupingExpressions.isEmpty) {
       // If there are no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
-      val groupingKey = groupProjection.apply(null)
+      val groupingKey = groupingProjection.apply(null)
       val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
       while (inputIter.hasNext) {
         val newInput = inputIter.next()
@@ -511,7 +187,7 @@
       while (inputIter.hasNext) {
         val newInput = inputIter.next()
         numInputRows += 1
-        val groupingKey = groupProjection.apply(newInput)
+        val groupingKey = groupingProjection.apply(newInput)
         var buffer: UnsafeRow = null
         if (i < fallbackStartsAt) {
           buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
@@ -565,25 +241,18 @@
   private def switchToSortBasedAggregation(): Unit = {
     logInfo("falling back to sort based aggregation.")

-    // Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
-    val newAggregationMode = aggregationMode match {
-      case (Some(Partial), None) => (Some(PartialMerge), None)
-      case (None, Some(Complete)) => (Some(Final), None)
-      case (Some(Final), Some(Complete)) => (Some(Final), None)
+    // Basically, the value of the KVIterator returned by the externalSorter
+    // will be just the aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
+    val newExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, Partial, _) =>
+        agg.copy(mode = PartialMerge)
+      case agg @ AggregateExpression(_, Complete, _) =>
+        agg.copy(mode = Final)
       case other => other
     }
-    aggregationMode = newAggregationMode
-
-    allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0)
-
-    // Basically the value of the KVIterator returned by externalSorter
-    // will just aggregation buffer. At here, we use inputAggBufferAttributes.
-    val newInputAttributes: Seq[Attribute] =
-      allAggregateFunctions.flatMap(_.inputAggBufferAttributes)
-
-    // Set up new processRow and generateOutput.
-    processRow = generateProcessRow(newInputAttributes)
-    generateOutput = generateResultProjection()
+    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+    sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes)

     // Step 5: Get the sorted iterator from the externalSorter.
     sortedKVIterator = externalSorter.sortedIterator()
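The rewrite above is the crux of the spill fallback: once rows come back from the external sorter they are already partial aggregation buffers, so any aggregate that was updating from raw input must instead merge buffers. A self-contained sketch of that mode mapping (the `AggMode` ADT and `rewriteForSort` are simplified stand-ins, not Spark's types):

```scala
sealed trait AggMode
case object Partial extends AggMode
case object PartialMerge extends AggMode
case object Final extends AggMode
case object Complete extends AggMode

object SortFallbackDemo extends App {
  // After spilling, each input row is itself a partial buffer, so update-style
  // modes become merge-style modes; merge-style modes are already correct.
  def rewriteForSort(mode: AggMode): AggMode = mode match {
    case Partial  => PartialMerge
    case Complete => Final
    case other    => other
  }

  assert(rewriteForSort(Partial) == PartialMerge)
  assert(rewriteForSort(Complete) == Final)
  assert(rewriteForSort(PartialMerge) == PartialMerge)
  assert(rewriteForSort(Final) == Final)
}
```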
@@ -632,6 +301,9 @@
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()

+  // The function used to process rows in a group.
+  private[this] var sortBasedProcessRow: (MutableRow, InternalRow) => Unit = null
+
   // Processes rows in the current group. It will stop when it finds a new group.
   private def processCurrentSortedGroup(): Unit = {
     // First, we need to copy nextGroupingKey to currentGroupingKey.
@@ -640,7 +312,7 @@
     // We create a variable to track if we see the next group.
     var findNextPartition = false
     // firstRowInNextGroup is the first row of this group. We first process it.
-    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+    sortBasedProcessRow(sortBasedAggregationBuffer, firstRowInNextGroup)

     // The search will stop when we see the next group or there is no
     // input row left in the iter.
@@ -655,16 +327,15 @@

       // Check if the current row belongs to the current group.
       if (currentGroupingKey.equals(groupingKey)) {
-        processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+        sortBasedProcessRow(sortBasedAggregationBuffer, inputAggregationBuffer)
         hasNext = sortedKVIterator.next()
       } else {
         // We find a new group.
         findNextPartition = true
-        // copyFrom will fail when
-        nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy()
-        firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy()
-
+        nextGroupingKey.copyFrom(groupingKey)
+        firstRowInNextGroup.copyFrom(inputAggregationBuffer)
       }
     }
     // We have not seen a new group. It means that there is no new row in the input
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 76b938cdb694e..83379ae90f703 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -42,16 +42,45 @@ object Utils {
     SortBasedAggregate(
       requiredChildDistributionExpressions = Some(groupingAttributes),
       groupingExpressions = groupingAttributes,
-      nonCompleteAggregateExpressions = Nil,
-      nonCompleteAggregateAttributes = Nil,
-      completeAggregateExpressions = completeAggregateExpressions,
-      completeAggregateAttributes = completeAggregateAttributes,
+      aggregateExpressions = completeAggregateExpressions,
+      aggregateAttributes = completeAggregateAttributes,
       initialInputBufferOffset = 0,
       resultExpressions = resultExpressions,
       child = child
     ) :: Nil
   }

+  private def createAggregate(
+      requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
+      groupingExpressions: Seq[NamedExpression] = Nil,
+      aggregateExpressions: Seq[AggregateExpression] = Nil,
+      aggregateAttributes: Seq[Attribute] = Nil,
+      initialInputBufferOffset: Int = 0,
+      resultExpressions: Seq[NamedExpression] = Nil,
+      child: SparkPlan): SparkPlan = {
+    val usesTungstenAggregate = TungstenAggregate.supportsAggregate(
+      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
+    if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = initialInputBufferOffset,
+        resultExpressions = resultExpressions,
+        child = child)
+    } else {
+      SortBasedAggregate(
+        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = initialInputBufferOffset,
+        resultExpressions = resultExpressions,
+        child = child)
+    }
+  }
+
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
@@ -59,9 +88,6 @@ object Utils {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
     // Check if we can use TungstenAggregate.
-    val usesTungstenAggregate = TungstenAggregate.supportsAggregate(
-      groupingExpressions,
-      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))

     // 1. Create an Aggregate Operator for partial aggregations.

@@ -73,29 +99,14 @@ object Utils {
       groupingAttributes ++
         partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

-    val partialAggregate = if (usesTungstenAggregate) {
-      TungstenAggregate(
-        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+    val partialAggregate = createAggregate(
+        requiredChildDistributionExpressions = None,
         groupingExpressions = groupingExpressions,
-        nonCompleteAggregateExpressions = partialAggregateExpressions,
-        nonCompleteAggregateAttributes = partialAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
+        aggregateExpressions = partialAggregateExpressions,
+        aggregateAttributes = partialAggregateAttributes,
         initialInputBufferOffset = 0,
         resultExpressions = partialResultExpressions,
         child = child)
-    } else {
-      SortBasedAggregate(
-        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
-        groupingExpressions = groupingExpressions,
-        nonCompleteAggregateExpressions = partialAggregateExpressions,
-        nonCompleteAggregateAttributes = partialAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
-        initialInputBufferOffset = 0,
-        resultExpressions = partialResultExpressions,
-        child = child)
-    }
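With `createAggregate` factored out, `planAggregateWithoutDistinct` always produces the same two-operator shape; only the physical operator class (Tungsten vs. sort-based) varies. A rough spark-shell illustration (the table name `t` is assumed to be registered, exchanges are inserted later by the planner, and the exact `explain()` labels vary by version):

```scala
// Assumes a temp table `t(key INT, value DOUBLE)` has been registered.
sqlContext.sql("SELECT key, avg(value) FROM t GROUP BY key").explain()
// Roughly:
//   TungstenAggregate(key=[key], functions=[(avg(value), mode=Final)])
//     +- TungstenAggregate(key=[key], functions=[(avg(value), mode=Partial)])
//        +- Scan t
```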
     // 2. Create an Aggregate Operator for final aggregations.
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
@@ -105,29 +116,14 @@ object Utils {
       expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
     }
-    val finalAggregate = if (usesTungstenAggregate) {
-      TungstenAggregate(
-        requiredChildDistributionExpressions = Some(groupingAttributes),
-        groupingExpressions = groupingAttributes,
-        nonCompleteAggregateExpressions = finalAggregateExpressions,
-        nonCompleteAggregateAttributes = finalAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
-        initialInputBufferOffset = groupingExpressions.length,
-        resultExpressions = resultExpressions,
-        child = partialAggregate)
-    } else {
-      SortBasedAggregate(
+    val finalAggregate = createAggregate(
         requiredChildDistributionExpressions = Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
-        nonCompleteAggregateExpressions = finalAggregateExpressions,
-        nonCompleteAggregateAttributes = finalAggregateAttributes,
-        completeAggregateExpressions = Nil,
-        completeAggregateAttributes = Nil,
+        aggregateExpressions = finalAggregateExpressions,
+        aggregateAttributes = finalAggregateAttributes,
         initialInputBufferOffset = groupingExpressions.length,
         resultExpressions = resultExpressions,
         child = partialAggregate)
-    }

     finalAggregate :: Nil
   }

@@ -140,99 +136,99 @@ object Utils {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {

-    val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
-    val usesTungstenAggregate = TungstenAggregate.supportsAggregate(
-      groupingExpressions,
-      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
-
     // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
     // DISTINCT aggregate function, all of those functions will have the same column expressions.
     // For example, it would be valid for functionsWithDistinct to be
     // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
     // disallowed because those two distinct aggregates have different column expressions.
-    val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
-    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+    val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctExpressions = distinctExpressions.map {
       case ne: NamedExpression => ne
       case other => Alias(other, other.toString)()
     }
-    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
+    val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)

     // 1. Create an Aggregate Operator for partial aggregations.
     val partialAggregate: SparkPlan = {
-      val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
-      val partialAggregateAttributes =
-        partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+      val aggregateAttributes = aggregateExpressions.map {
+        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+      }
      // We will group by the original grouping expressions, plus an additional expression for the
      // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
      // expressions will be [key, value].
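The comment above is the heart of the distinct strategy: grouping by `[key, value]` first makes every distinct value appear exactly once per key, after which a plain (non-distinct) aggregate over the regrouped rows is correct. A toy, Spark-free illustration with plain Scala collections:

```scala
object DistinctGroupingDemo extends App {
  // (key, value) pairs; note the duplicate ("a", 1).
  val rows = Seq(("a", 1), ("a", 1), ("a", 2), ("b", 3))

  // Deduplicating (key, value) pairs plays the role of GROUP BY key, value.
  val deduped = rows.distinct

  // A plain AVG over the deduped rows equals AVG(DISTINCT value) per key.
  val avgDistinct = deduped.groupBy(_._1).map { case (k, vs) =>
    k -> vs.map(_._2).sum.toDouble / vs.size
  }

  assert(avgDistinct("a") == 1.5) // (1 + 2) / 2; the duplicate 1 is counted once
  assert(avgDistinct("b") == 3.0)
}
```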
-      val partialAggregateGroupingExpressions =
-        groupingExpressions ++ namedDistinctColumnExpressions
-      val partialAggregateResult =
-        groupingAttributes ++
-          distinctColumnAttributes ++
-          partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-      if (usesTungstenAggregate) {
-        TungstenAggregate(
-          requiredChildDistributionExpressions = None,
-          groupingExpressions = partialAggregateGroupingExpressions,
-          nonCompleteAggregateExpressions = partialAggregateExpressions,
-          nonCompleteAggregateAttributes = partialAggregateAttributes,
-          completeAggregateExpressions = Nil,
-          completeAggregateAttributes = Nil,
-          initialInputBufferOffset = 0,
-          resultExpressions = partialAggregateResult,
-          child = child)
-      } else {
-        SortBasedAggregate(
-          requiredChildDistributionExpressions = None,
-          groupingExpressions = partialAggregateGroupingExpressions,
-          nonCompleteAggregateExpressions = partialAggregateExpressions,
-          nonCompleteAggregateAttributes = partialAggregateAttributes,
-          completeAggregateExpressions = Nil,
-          completeAggregateAttributes = Nil,
-          initialInputBufferOffset = 0,
-          resultExpressions = partialAggregateResult,
-          child = child)
-      }
+      createAggregate(
+        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        resultExpressions = groupingAttributes ++ distinctAttributes ++
+          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = child)
     }

     // 2. Create an Aggregate Operator for partial merge aggregations.
     val partialMergeAggregate: SparkPlan = {
-      val partialMergeAggregateExpressions =
-        functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-      val partialMergeAggregateAttributes =
-        partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-      val partialMergeAggregateResult =
-        groupingAttributes ++
-          distinctColumnAttributes ++
-          partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-      if (usesTungstenAggregate) {
-        TungstenAggregate(
-          requiredChildDistributionExpressions = Some(groupingAttributes),
-          groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
-          nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
-          nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
-          completeAggregateExpressions = Nil,
-          completeAggregateAttributes = Nil,
-          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
-          resultExpressions = partialMergeAggregateResult,
-          child = partialAggregate)
-      } else {
-        SortBasedAggregate(
-          requiredChildDistributionExpressions = Some(groupingAttributes),
-          groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
-          nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
-          nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
-          completeAggregateExpressions = Nil,
-          completeAggregateAttributes = Nil,
-          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
-          resultExpressions = partialMergeAggregateResult,
-          child = partialAggregate)
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map {
+        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
       }
+      createAggregate(
+        requiredChildDistributionExpressions =
+          Some(groupingAttributes ++ distinctAttributes),
+        groupingExpressions = groupingAttributes ++ distinctAttributes,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+        resultExpressions = groupingAttributes ++ distinctAttributes ++
+          aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = partialAggregate)
+    }
+
+    // 3. Create an Aggregate Operator for partial aggregation (for distinct).
+    val distinctColumnAttributeLookup = distinctExpressions.zip(distinctAttributes).toMap
+    val rewrittenDistinctFunctions = functionsWithDistinct.map {
+      // Children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children
+      // with AttributeReferences.
+      case agg @ AggregateExpression(aggregateFunction, mode, true) =>
+        aggregateFunction.transformDown(distinctColumnAttributeLookup)
+          .asInstanceOf[AggregateFunction]
     }

-    // 3. Create an Aggregate Operator for the final aggregation.
+    val partialDistinctAggregate: SparkPlan = {
+      val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val mergeAggregateAttributes = mergeAggregateExpressions.map {
+        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
+      }
+      val (distinctAggregateExpressions, distinctAggregateAttributes) =
+        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
+          // We rewrite the aggregate function to a non-distinct aggregation because
+          // its input will have distinct arguments.
+          // We keep isDistinct set to true so that, when users look at the query plan,
+          // they can still see the distinct aggregations.
+          val expr = AggregateExpression(func, Partial, isDistinct = true)
+          // Use the original AggregateFunction to look up attributes, because the map behind
+          // aggregateFunctionToAttribute was built from the original (un-rewritten) functions.
+          val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+          (expr, attr)
+        }.unzip
+
+      val partialAggregateResult = groupingAttributes ++
+        mergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) ++
+        distinctAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+      createAggregate(
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = mergeAggregateExpressions ++ distinctAggregateExpressions,
+        aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes,
+        initialInputBufferOffset = (groupingAttributes ++ distinctAttributes).length,
+        resultExpressions = partialAggregateResult,
+        child = partialMergeAggregate)
+    }
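The `transformDown(distinctColumnAttributeLookup)` call works because a Scala `Map` is itself a `PartialFunction`, so any child expression matching a key is swapped for the attribute that the earlier stages already grouped on. A toy expression tree showing the same substitution (the `Expr`/`Attr` types and this `transformDown` are simplified stand-ins for Catalyst's):

```scala
object RewriteDemo extends App {
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr

  // Pre-order transform: try the rule at this node, then recurse into children.
  def transformDown(e: Expr, rule: PartialFunction[Expr, Expr]): Expr =
    rule.applyOrElse(e, identity[Expr]) match {
      case Add(l, r) => Add(transformDown(l, rule), transformDown(r, rule))
      case leaf      => leaf
    }

  // `value1 + value2` was already evaluated and grouped on by the earlier stages,
  // so the distinct aggregate's child is replaced by the grouping attribute.
  val evaluated: Expr = Add(Attr("value1"), Attr("value2"))
  val lookup: Map[Expr, Expr] = Map(evaluated -> Attr("(value1 + value2)"))

  assert(transformDown(evaluated, lookup) == Attr("(value1 + value2)"))
}
```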
+
+    // 4. Create an Aggregate Operator for the final aggregation.
     val finalAndCompleteAggregate: SparkPlan = {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
       // The attributes of the final aggregation buffer, which is presented as input to the result
@@ -241,49 +237,27 @@ object Utils {
         expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
       }

-      val distinctColumnAttributeLookup =
-        distinctColumnExpressions.zip(distinctColumnAttributes).toMap
-      val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
-        // Children of an AggregateFunction with DISTINCT keyword has already
-        // been evaluated. At here, we need to replace original children
-        // to AttributeReferences.
-        case agg @ AggregateExpression(aggregateFunction, mode, true) =>
-          val rewrittenAggregateFunction = aggregateFunction
-            .transformDown(distinctColumnAttributeLookup)
-            .asInstanceOf[AggregateFunction]
+      val (distinctAggregateExpressions, distinctAggregateAttributes) =
+        rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
          // We rewrite the aggregate function to a non-distinct aggregation because
          // its input will have distinct arguments.
          // We keep isDistinct set to true so that, when users look at the query plan,
          // they can still see the distinct aggregations.
-          val rewrittenAggregateExpression =
-            AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true)
-
-          val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true)
-          (rewrittenAggregateExpression, aggregateFunctionAttribute)
+          val expr = AggregateExpression(func, Final, isDistinct = true)
+          // Use the original AggregateFunction to look up attributes, because the map behind
+          // aggregateFunctionToAttribute was built from the original (un-rewritten) functions.
+          val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+          (expr, attr)
        }.unzip

-      if (usesTungstenAggregate) {
-        TungstenAggregate(
-          requiredChildDistributionExpressions = Some(groupingAttributes),
-          groupingExpressions = groupingAttributes,
-          nonCompleteAggregateExpressions = finalAggregateExpressions,
-          nonCompleteAggregateAttributes = finalAggregateAttributes,
-          completeAggregateExpressions = completeAggregateExpressions,
-          completeAggregateAttributes = completeAggregateAttributes,
-          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
-          resultExpressions = resultExpressions,
-          child = partialMergeAggregate)
-      } else {
-        SortBasedAggregate(
-          requiredChildDistributionExpressions = Some(groupingAttributes),
-          groupingExpressions = groupingAttributes,
-          nonCompleteAggregateExpressions = finalAggregateExpressions,
-          nonCompleteAggregateAttributes = finalAggregateAttributes,
-          completeAggregateExpressions = completeAggregateExpressions,
-          completeAggregateAttributes = completeAggregateAttributes,
-          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
-          resultExpressions = resultExpressions,
-          child = partialMergeAggregate)
-      }
+
+      createAggregate(
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = finalAggregateExpressions ++ distinctAggregateExpressions,
+        aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = resultExpressions,
+        child = partialDistinctAggregate)
     }

     finalAndCompleteAggregate :: Nil
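Taken together, the single-distinct code path above now always emits four aggregate operators instead of choosing a strategy per flag. A rough sketch of the resulting plan for a single-distinct query (spark-shell style; the table, columns, and operator labels are illustrative, and exchanges are omitted):

```scala
// Assumes a temp table `t(key INT, value INT, other INT)` has been registered.
sqlContext.sql(
  "SELECT key, count(DISTINCT value), sum(other) FROM t GROUP BY key").explain()
// Roughly, numbered as in the planner code above:
//   Aggregate(key=[key],        functions=[sum(other, Final), count(value, Final)])               // 4.
//     +- Aggregate(key=[key],        functions=[sum(other, PartialMerge), count(value, Partial)]) // 3.
//        +- Aggregate(key=[key, value], functions=[sum(other, PartialMerge)])                     // 2.
//           +- Aggregate(key=[key, value], functions=[sum(other, Partial)])                       // 1.
//              +- Scan t
```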
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 064c0004b801e..5550198c02fbf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution

 import scala.collection.JavaConverters._

-import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
@@ -552,80 +551,73 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
   test("single distinct column set") {
-    Seq(true, false).foreach { specializeSingleDistinctAgg =>
-      val conf =
-        (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key,
-          specializeSingleDistinctAgg.toString)
-      withSQLConf(conf) {
-        // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
-        checkAnswer(
-          sqlContext.sql(
-            """
-              |SELECT
-              |  min(distinct value1),
-              |  sum(distinct value1),
-              |  avg(value1),
-              |  avg(value2),
-              |  max(distinct value1)
-              |FROM agg2
-            """.stripMargin),
-          Row(-60, 70.0, 101.0/9.0, 5.6, 100))
-
-        checkAnswer(
-          sqlContext.sql(
-            """
-              |SELECT
-              |  mydoubleavg(distinct value1),
-              |  avg(value1),
-              |  avg(value2),
-              |  key,
-              |  mydoubleavg(value1 - 1),
-              |  mydoubleavg(distinct value1) * 0.1,
-              |  avg(value1 + value2)
-              |FROM agg2
-              |GROUP BY key
-            """.stripMargin),
-          Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
-            Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
-            Row(null, null, 3.0, 3, null, null, null) ::
-            Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
-
-        checkAnswer(
-          sqlContext.sql(
-            """
-              |SELECT
-              |  key,
-              |  mydoubleavg(distinct value1),
-              |  mydoublesum(value2),
-              |  mydoublesum(distinct value1),
-              |  mydoubleavg(distinct value1),
-              |  mydoubleavg(value1)
-              |FROM agg2
-              |GROUP BY key
-            """.stripMargin),
-          Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
-            Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
-            Row(3, null, 3.0, null, null, null) ::
-            Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
-
-        checkAnswer(
-          sqlContext.sql(
-            """
-              |SELECT
-              |  count(value1),
-              |  count(*),
-              |  count(1),
-              |  count(DISTINCT value1),
-              |  key
-              |FROM agg2
-              |GROUP BY key
-            """.stripMargin),
-          Row(3, 3, 3, 2, 1) ::
-            Row(3, 4, 4, 2, 2) ::
-            Row(0, 2, 2, 0, 3) ::
-            Row(3, 4, 4, 3, null) :: Nil)
-      }
-    }
+    // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  min(distinct value1),
+          |  sum(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  max(distinct value1)
+          |FROM agg2
+        """.stripMargin),
+      Row(-60, 70.0, 101.0/9.0, 5.6, 100))
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  mydoubleavg(distinct value1),
+          |  avg(value1),
+          |  avg(value2),
+          |  key,
+          |  mydoubleavg(value1 - 1),
+          |  mydoubleavg(distinct value1) * 0.1,
+          |  avg(value1 + value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+        Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+        Row(null, null, 3.0, 3, null, null, null) ::
+        Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoubleavg(distinct value1),
+          |  mydoublesum(value2),
+          |  mydoublesum(distinct value1),
+          |  mydoubleavg(distinct value1),
+          |  mydoubleavg(value1)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+        Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+        Row(3, null, 3.0, null, null, null) ::
+        Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  count(value1),
+          |  count(*),
+          |  count(1),
+          |  count(DISTINCT value1),
+          |  key
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(3, 3, 3, 2, 1) ::
+        Row(3, 4, 4, 2, 2) ::
+        Row(0, 2, 2, 0, 3) ::
+        Row(3, 4, 4, 3, null) :: Nil)
   }

   test("single distinct multiple columns set") {