Merge pull request #148 from mbautin/csd-1.6_SPARK-12213
Backport: [SPARK-12213][SQL] use multiple partitions for single distinct query
markhamstra committed Feb 1, 2016
2 parents edc1192 + b6db2cc commit bea8845
Showing 10 changed files with 422 additions and 990 deletions.
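For orientation, the query shape this backport targets is a single distinct aggregate alongside grouping expressions. A minimal sketch (table and column names are hypothetical, and `sqlContext` is assumed to be in scope, e.g. in spark-shell):

    // After this change, a query like this is planned by the regular
    // Aggregation strategy across multiple partitions instead of being
    // rewritten by DistinctAggregationRewriter.
    sqlContext.sql(
      "SELECT category, COUNT(DISTINCT user_id) FROM events GROUP BY category")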
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst
 
 private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
-
-  protected[spark] def specializeSingleDistinctAggPlanning: Boolean
 }
 
 /**
@@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf {
   override def caseSensitiveAnalysis: Boolean = {
     throw new UnsupportedOperationException
   }
-
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
-    throw new UnsupportedOperationException
-  }
 }
 
 /** A CatalystConf that can be used for local testing. */
 case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
 }
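With the flag removed, CatalystConf is reduced to the case-sensitivity switch. A sketch of the test-conf construction that remains, grounded only in the code shown above:

    import org.apache.spark.sql.catalyst.SimpleCatalystConf

    // The local-testing conf now carries nothing but case sensitivity.
    val conf = SimpleCatalystConf(caseSensitiveAnalysis = true)
    assert(conf.caseSensitiveAnalysis)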
@@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan]
       .filter(_.isDistinct)
       .groupBy(_.aggregateFunction.children.toSet)
 
-    val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
-      // When the flag is set to specialize single distinct agg planning,
-      // we will rely on our Aggregation strategy to handle queries with a single
-      // distinct column.
-      distinctAggGroups.size > 1
-    } else {
-      distinctAggGroups.size >= 1
-    }
-    if (shouldRewrite) {
+    // Aggregation strategy can handle the query with single distinct
+    if (distinctAggGroups.size > 1) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = new AttributeReference("gid", IntegerType, false)()
       val groupByMap = a.groupingExpressions.collect {
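The predicate above is the heart of the change: the rewrite, which introduces the `gid` grouping-id attribute, now fires only when the distinct aggregates span more than one distinct argument set; a single distinct group is left to the Aggregation strategy. A self-contained model of that decision, with plain string sets standing in for `aggregateFunction.children.toSet` (illustrative only, not Spark API):

    object RewriteDecision {
      // One entry per distinct aggregate: the set of argument columns.
      def shouldRewrite(distinctArgSets: Seq[Set[String]]): Boolean =
        distinctArgSets.toSet.size > 1

      def main(args: Array[String]): Unit = {
        // COUNT(DISTINCT a), SUM(DISTINCT a): one group, no rewrite.
        println(shouldRewrite(Seq(Set("a"), Set("a")))) // false
        // COUNT(DISTINCT a), COUNT(DISTINCT b): two groups, rewritten.
        println(shouldRewrite(Seq(Set("a"), Set("b")))) // true
      }
    }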
15 changes: 0 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

@@ -449,18 +449,6 @@ private[spark] object SQLConf {
     doc = "When true, we could use `datasource`.`path` as table in SQL query"
   )
 
-  val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING =
-    booleanConf("spark.sql.specializeSingleDistinctAggPlanning",
-      defaultValue = Some(false),
-      isPublic = false,
-      doc = "When true, if a query only has a single distinct column and it has " +
-        "grouping expressions, we will use our planner rule to handle this distinct " +
-        "column (other cases are handled by DistinctAggregationRewriter). " +
-        "When false, we will always use DistinctAggregationRewriter to plan " +
-        "aggregation queries with DISTINCT keyword. This is an internal flag that is " +
-        "used to benchmark the performance impact of using DistinctAggregationRewriter to " +
-        "plan aggregation queries with a single distinct column.")
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
     val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean =
-    getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING)
-
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
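Before this commit the internal flag could be flipped like any other SQL configuration entry; after it, the key no longer exists and setting it has no effect on planning. For reference, the key taken verbatim from the removed definition above:

    // No longer meaningful after this commit: the conf entry was deleted.
    sqlContext.setConf("spark.sql.specializeSingleDistinctAggPlanning", "true")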

Large diffs are not rendered by default.

@@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 case class SortBasedAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
@@ -42,10 +40,8 @@ case class SortBasedAggregate(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def outputsUnsafeRows: Boolean = false
-
+  override def outputsUnsafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = false
-
   override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -76,31 +72,24 @@ case class SortBasedAggregate(
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input iterator is empty,
         // so return an empty iterator.
-        Iterator[InternalRow]()
+        Iterator[UnsafeRow]()
       } else {
-        val groupingKeyProjection =
-          UnsafeProjection.create(groupingExpressions, child.output)
-
         val outputIter = new SortBasedAggregationIterator(
-          groupingKeyProjection,
-          groupingExpressions.map(_.toAttribute),
+          groupingExpressions,
           child.output,
           iter,
-          nonCompleteAggregateExpressions,
-          nonCompleteAggregateAttributes,
-          completeAggregateExpressions,
-          completeAggregateAttributes,
+          aggregateExpressions,
+          aggregateAttributes,
          initialInputBufferOffset,
          resultExpressions,
          newMutableProjection,
-          outputsUnsafeRows,
          numInputRows,
          numOutputRows)
        if (!hasInput && groupingExpressions.isEmpty) {
          // There is no input and there is no grouping expressions.
          // We need to output a single row as the output.
          numOutputRows += 1
-          Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+          Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
        } else {
          outputIter
        }
@@ -109,7 +98,7 @@ case class SortBasedAggregate(
   }
 
   override def simpleString: String = {
-    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    val allAggregateExpressions = aggregateExpressions
 
     val keyString = groupingExpressions.mkString("[", ",", "]")
     val functionString = allAggregateExpressions.mkString("[", ",", "]")
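The four aggregate parameter lists collapse to two because each AggregateExpression already records its evaluation mode, so Partial/PartialMerge/Final ("non-complete") and Complete expressions can travel in one Seq. A sketch of recovering the old split from the merged list, assuming this branch's catalyst aggregate API:

    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}

    // Partition the merged list by the mode carried on each expression.
    def splitByMode(aggregateExpressions: Seq[AggregateExpression])
        : (Seq[AggregateExpression], Seq[AggregateExpression]) =
      aggregateExpressions.partition(_.mode != Complete)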
@@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric
 
 /**
  * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
- * sorted by values of [[groupingKeyAttributes]].
+ * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
-    groupingKeyProjection: InternalRow => InternalRow,
-    groupingKeyAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean,
     numInputRows: LongSQLMetric,
     numOutputRows: LongSQLMetric)
   extends AggregationIterator(
-    groupingKeyAttributes,
+    groupingExpressions,
     valueAttributes,
-    nonCompleteAggregateExpressions,
-    nonCompleteAggregateAttributes,
-    completeAggregateExpressions,
-    completeAggregateAttributes,
+    aggregateExpressions,
+    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
-    newMutableProjection,
-    outputsUnsafeRows) {
-
-  override protected def newBuffer: MutableRow = {
-    val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
+    newMutableProjection) {
+
+  /**
+   * Creates a new aggregation buffer and initializes buffer values
+   * for all aggregate functions.
+   */
+  private def newBuffer: MutableRow = {
+    val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
     val bufferRowSize: Int = bufferSchema.length
 
     val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
@@ -76,10 +73,10 @@ class SortBasedAggregationIterator(
   ///////////////////////////////////////////////////////////////////////////
 
   // The partition key of the current partition.
-  private[this] var currentGroupingKey: InternalRow = _
+  private[this] var currentGroupingKey: UnsafeRow = _
 
   // The partition key of next partition.
-  private[this] var nextGroupingKey: InternalRow = _
+  private[this] var nextGroupingKey: UnsafeRow = _
 
   // The first row of next partition.
   private[this] var firstRowInNextGroup: InternalRow = _
@@ -94,7 +91,7 @@ class SortBasedAggregationIterator(
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
       val inputRow = inputIterator.next()
-      nextGroupingKey = groupingKeyProjection(inputRow).copy()
+      nextGroupingKey = groupingProjection(inputRow).copy()
      firstRowInNextGroup = inputRow.copy()
      numInputRows += 1
      sortedInputHasNewGroup = true
@@ -120,7 +117,7 @@ class SortBasedAggregationIterator(
     while (!findNextPartition && inputIterator.hasNext) {
       // Get the grouping key.
       val currentRow = inputIterator.next()
-      val groupingKey = groupingKeyProjection(currentRow)
+      val groupingKey = groupingProjection(currentRow)
      numInputRows += 1
 
      // Check if the current row belongs the current input row.
@@ -146,7 +143,7 @@ class SortBasedAggregationIterator(
 
   override final def hasNext: Boolean = sortedInputHasNewGroup
 
-  override final def next(): InternalRow = {
+  override final def next(): UnsafeRow = {
     if (hasNext) {
       // Process the current group.
       processCurrentSortedGroup()
@@ -162,8 +159,8 @@ class SortBasedAggregationIterator(
     }
   }
 
-  def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     initializeBuffer(sortBasedAggregationBuffer)
-    generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+    generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
   }
 }
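Because the iterator now produces UnsafeRow end to end, even the grouping key for the no-GROUP-BY, no-input case has to be an UnsafeRow; `UnsafeRow.createFromByteArray(0, 0)` (taken from the diff above) builds the zero-byte, zero-field row used for it. In isolation:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // A zero-field, zero-byte row standing in for the empty grouping key.
    val emptyKey: UnsafeRow = UnsafeRow.createFromByteArray(0, 0)
    assert(emptyKey.numFields == 0)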
@@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType
 case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
 
   private[this] val aggregateBufferAttributes = {
-    (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
-      .flatMap(_.aggregateFunction.aggBufferAttributes)
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
   }
 
-  require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+  require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes))
 
   override private[sql] lazy val metrics = Map(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
@@ -53,9 +50,7 @@ case class TungstenAggregate(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
   override def outputsUnsafeRows: Boolean = true
-
   override def canProcessUnsafeRows: Boolean = true
-
   override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -94,10 +89,8 @@ case class TungstenAggregate(
       val aggregationIterator =
         new TungstenAggregationIterator(
           groupingExpressions,
-          nonCompleteAggregateExpressions,
-          nonCompleteAggregateAttributes,
-          completeAggregateExpressions,
-          completeAggregateAttributes,
+          aggregateExpressions,
+          aggregateAttributes,
          initialInputBufferOffset,
          resultExpressions,
          newMutableProjection,
@@ -119,7 +112,7 @@ case class TungstenAggregate(
   }
 
   override def simpleString: String = {
-    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    val allAggregateExpressions = aggregateExpressions
 
     testFallbackStartsAt match {
       case None =>
@@ -135,9 +128,7 @@ case class TungstenAggregate(
 }
 
 object TungstenAggregate {
-  def supportsAggregate(
-      groupingExpressions: Seq[Expression],
-      aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+  def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
     val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
     UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
   }
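With grouping expressions dropped from its signature, `supportsAggregate` gates TungstenAggregate purely on whether the aggregation buffer schema fits UnsafeFixedWidthAggregationMap, roughly: every buffer field must be a mutable fixed-width type. A sketch of that check with a hypothetical buffer schema (say, for COUNT plus SUM):

    import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
    import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}

    val bufferSchema = StructType(Seq(
      StructField("count", LongType),
      StructField("sum", DoubleType)))
    // Fixed-width numeric fields are supported by the unsafe map.
    println(UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(bufferSchema))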
