Cleanup addition of ordering requirements
marmbrus committed Apr 15, 2015
1 parent b198278 commit 7ddd656
Showing 8 changed files with 132 additions and 69 deletions.
@@ -234,7 +234,7 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
}

object RowOrdering {
def getOrderingFromDataTypes(dataTypes: Seq[DataType]): RowOrdering =
def forSchema(dataTypes: Seq[DataType]): RowOrdering =
new RowOrdering(dataTypes.zipWithIndex.map {
case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
})
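
For illustration, a sketch that is not part of the commit (imports assume the Spark 1.3-era package layout): the renamed factory builds a positional, all-ascending ordering over a row schema.

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.catalyst.expressions.RowOrdering

// Compares column 0 as an Int, then column 1 as a String, both ascending.
val ordering = RowOrdering.forSchema(Seq(IntegerType, StringType))

assert(ordering.compare(Row(1, "b"), Row(2, "a")) < 0) // decided by the first column
assert(ordering.compare(Row(1, "a"), Row(1, "b")) < 0) // tie broken by the second column
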
@@ -94,6 +94,9 @@ sealed trait Partitioning {
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean

/** Returns the expressions that are used to key the partitioning. */
def keyExpressions: Seq[Expression]
}

case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case UnknownPartitioning(_) => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

/**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = expressions

override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}

override def keyExpressions: Seq[Expression] = ordering.map(_.child)

override def eval(input: Row): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
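
A sketch of the new hook (not from the commit; attribute construction and partition counts are assumed): keyExpressions tells a planner rule which expressions a partitioning is keyed on, and therefore which sort orders a shuffle into that partitioning could produce as a side effect.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, SinglePartition}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()

HashPartitioning(Seq(a, b), 200).keyExpressions                      // Seq(a, b): the hash expressions
RangePartitioning(Seq(SortOrder(a, Ascending)), 200).keyExpressions  // Seq(a): children of the sort orders
SinglePartition.keyExpressions                                       // Nil: nothing to key a sort on
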
@@ -1080,7 +1080,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
Batch("Add exchange", Once, AddExchange(self)) :: Nil
Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
}

protected[sql] def openSession(): SQLSession = {
147 changes: 101 additions & 46 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair

object Exchange {
/** Returns true when the ordering expressions are a subset of the key. */
def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
}
}
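
A minimal usage sketch (assumed, not from the commit): the check is a plain subset test, so a shuffle that hash-partitions on (a, b) can also deliver rows sorted by a, but not by an unrelated column c.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val c = AttributeReference("c", IntegerType)()
val hashOnAB = HashPartitioning(Seq(a, b), 200)

Exchange.canSortWithShuffle(hashOnAB, Seq(SortOrder(a, Ascending)))  // true: `a` is one of the key expressions
Exchange.canSortWithShuffle(hashOnAB, Seq(SortOrder(c, Ascending)))  // false: `c` is not part of the key
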

/**
* Shuffle data according to a new partition rule, and sort inside each partition if necessary.
* @param newPartitioning The new partitioning way that required by parent
* @param sort Whether we will sort inside each partition
* @param child Child operator
* :: DeveloperApi ::
* Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
* resulting partition based on expressions from the partition key. It is invalid to construct an
* exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
*/
@DeveloperApi
case class Exchange(
newPartitioning: Partitioning,
sort: Boolean,
newOrdering: Seq[SortOrder],
child: SparkPlan)
extends UnaryNode {

override def outputPartitioning: Partitioning = newPartitioning

override def outputOrdering = newOrdering

override def output: Seq[Attribute] = child.output

/** We must copy rows when sort based shuffle is on */
@@ -51,6 +60,20 @@ case class Exchange(
private val bypassMergeThreshold =
child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)

private val keyOrdering = {
if (newOrdering.nonEmpty) {
val key = newPartitioning.keyExpressions
val boundOrdering = newOrdering.map { o =>
val ordinal = key.indexOf(o.child)
if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
}
new RowOrdering(boundOrdering)
} else {
null // Ordering will not be used
}
}

override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
@@ -62,7 +85,9 @@ case class Exchange(
// we can avoid the defensive copies to improve performance. In the long run, we probably
// want to include information in shuffle dependencies to indicate whether elements in the
// source RDD should be copied.
val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold

val rdd = if (willMergeSort || newOrdering.nonEmpty) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -75,16 +100,12 @@ case class Exchange(
}
}
val part = new HashPartitioner(numPartitions)
val shuffled = sort match {
case false => new ShuffledRDD[Row, Row, Row](rdd, part)
case true =>
val sortingExpressions = expressions.zipWithIndex.map {
case (exp, index) =>
new SortOrder(BoundReference(index, exp.dataType, exp.nullable), Ascending)
}
val ordering = new RowOrdering(sortingExpressions, child.output)
new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
}
val shuffled =
if (newOrdering.nonEmpty) {
new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)

@@ -102,7 +123,12 @@ case class Exchange(
implicit val ordering = new RowOrdering(sortingExpressions, child.output)

val part = new RangePartitioner(numPartitions, rdd, ascending = true)
val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
val shuffled =
if (newOrdering.nonEmpty) {
new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))

shuffled.map(_._1)
@@ -135,27 +161,35 @@ case class Exchange(
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[Exchange]] Operators where required.
* each operator by inserting [[Exchange]] Operators where required. Also ensures that the
* input partition ordering requirements of each operator are met.
*/
private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
def numPartitions: Int = sqlContext.conf.numShufflePartitions

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
// Check if every child's outputPartitioning satisfies the corresponding
// True iff every child's outputPartitioning satisfies the corresponding
// required data distribution.
def meetsRequirements: Boolean =
!operator.requiredChildDistribution.zip(operator.children).map {
operator.requiredChildDistribution.zip(operator.children).forall {
case (required, child) =>
val valid = child.outputPartitioning.satisfies(required)
logDebug(
s"${if (valid) "Valid" else "Invalid"} distribution," +
s"required: $required current: ${child.outputPartitioning}")
valid
}.exists(!_)
}

// Check if outputPartitionings of children are compatible with each other.
// True iff any of the children are incorrectly sorted.
def needsAnySort: Boolean =
operator.requiredChildOrdering.zip(operator.children).exists {
case (required, child) => required.nonEmpty && required != child.outputOrdering
}


// True iff outputPartitionings of children are compatible with each other.
// It is possible that every child satisfies its required data distribution
// but two children have incompatible outputPartitionings. For example,
// A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
case Seq(a,b) => a compatibleWith b
}.exists(!_)

// Check if the partitioning we want to ensure is the same as the child's output
// partitioning. If so, we do not need to add the Exchange operator.
def addExchangeIfNecessary(
// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
partitioning: Partitioning,
child: SparkPlan,
rowOrdering: Option[Ordering[Row]] = None): SparkPlan = {
val needSort = child.outputOrdering != rowOrdering
if (child.outputPartitioning != partitioning || needSort) {
// TODO: if only needSort, we need only sort each partition instead of an Exchange
Exchange(partitioning, sort = needSort, child)
rowOrdering: Seq[SortOrder],
child: SparkPlan): SparkPlan = {
val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
val needsShuffle = child.outputPartitioning != partitioning
val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)

if (needSort && needsShuffle && canSortWithShuffle) {
Exchange(partitioning, rowOrdering, child)
} else {
child
val withShuffle = if (needsShuffle) {
Exchange(partitioning, Nil, child)
} else {
child
}

val withSort = if (needSort) {
Sort(rowOrdering, global = false, withShuffle)
} else {
withShuffle
}

withSort
}
}

if (meetsRequirements && compatible) {
if (meetsRequirements && compatible && !needsAnySort) {
operator
} else {
// At least one child does not satisfies its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
val repartitionedChildren = operator.requiredChildDistribution.zip(
operator.children.zip(operator.requiredChildOrdering)
).map {
case (AllTuples, (child, _)) =>
addExchangeIfNecessary(SinglePartition, child)
case (ClusteredDistribution(clustering), (child, rowOrdering)) =>
addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child, rowOrdering)
case (OrderedDistribution(ordering), (child, None)) =>
addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
case (UnspecifiedDistribution, (child, _)) => child
case (dist, _) => sys.error(s"Don't know how to ensure $dist")
val requirements =
(operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)

val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), Nil, child)

case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
Sort(rowOrdering, global = false, child)

case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}
operator.withNewChildren(repartitionedChildren)

operator.withNewChildren(fixedChildren)
}
}
}
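
To make the decision above concrete, here is a hedged sketch (not code from this commit; the child value is a placeholder) of the two plan shapes the rule produces for an operator that wants its child hash-partitioned on `a`:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.{Exchange, Sort, SparkPlan}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val child: SparkPlan = ???  // any child plan; placeholder for the sketch

// Required ordering `a.asc` can be computed from the partitioning key, so a single
// sorting shuffle is enough:
val sortingShuffle =
  Exchange(HashPartitioning(Seq(a), 200), Seq(SortOrder(a, Ascending)), child)

// Required ordering `b.asc` is not part of the key, so the rule shuffles first and then
// layers a partition-local sort on top:
val shuffleThenSort =
  Sort(Seq(SortOrder(b, Ascending)), global = false,
    Exchange(HashPartitioning(Seq(a), 200), Nil, child))
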
@@ -73,10 +73,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
Seq.fill(children.size)(UnspecifiedDistribution)

/** Specifies how data is ordered in each partition. */
def outputOrdering: Option[Ordering[Row]] = None
def outputOrdering: Seq[SortOrder] = Nil

/** Specifies the sort order required of the input data, one entry per child of this operator. */
def requiredChildOrdering: Seq[Option[Ordering[Row]]] = Seq.fill(children.size)(None)
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

/**
* Runs this query returning the result as an RDD.
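
As a hedged sketch of the revised contract (the operator below is hypothetical, not part of Spark; assume it compiles inside org.apache.spark.sql.execution so that UnaryNode is visible), an operator can now declare a required distribution and a required per-child sort order together, and EnsureRequirements will insert whatever Exchange or Sort operators are needed to honor them.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}

// Hypothetical operator: wants each child partition clustered and sorted on `keys`.
case class ExampleSortedOp(keys: Seq[Expression], child: SparkPlan) extends UnaryNode {

  override def output: Seq[Attribute] = child.output

  override def requiredChildDistribution: Seq[Distribution] =
    ClusteredDistribution(keys) :: Nil

  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
    Seq(keys.map(SortOrder(_, Ascending)))

  // Rows are emitted in the order the (sorted) child delivers them.
  override def outputOrdering: Seq[SortOrder] = keys.map(SortOrder(_, Ascending))

  override def execute(): RDD[Row] = child.execute()  // placeholder body for the sketch
}

Operators with no preference simply keep the defaults shown above (UnspecifiedDistribution and Nil).
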
@@ -308,7 +308,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(
HashPartitioning(expressions, numPartitions), sort = false, planLater(child)) :: Nil
HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
@@ -42,12 +42,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
iter.map(resuableProjection)
}

/**
* outputOrdering of Project is not always same with child's outputOrdering if the certain
* key is pruned, however, if the key is pruned then we must not require child using this
* ordering from upper layer, so it is fine to keep it to avoid some unnecessary sorting.
*/
override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

/**
@@ -63,7 +58,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
iter.filter(conditionEvaluator)
}

override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

/**
@@ -111,7 +106,7 @@ case class Limit(limit: Int, child: SparkPlan)
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = SinglePartition

override def outputOrdering: Option[Ordering[Row]] = child.outputOrdering
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def executeCollect(): Array[Row] = child.executeTake(limit)

@@ -158,7 +153,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Pick num splits based on |limit|.
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1)

override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
override def outputOrdering: Seq[SortOrder] = sortOrder
}

/**
Expand All @@ -185,7 +180,7 @@ case class Sort(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
override def outputOrdering: Seq[SortOrder] = sortOrder
}

/**
@@ -217,7 +212,7 @@ case class ExternalSort(

override def output: Seq[Attribute] = child.output

override def outputOrdering: Option[Ordering[Row]] = Some(new RowOrdering(sortOrder))
override def outputOrdering: Seq[SortOrder] = sortOrder
}
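
Finally, a rough sketch (hypothetical plan fragment, not from the commit; the child value is a placeholder) of why the switch to Seq[SortOrder] matters: the expression-level ordering produced by a Sort now flows unchanged through order-preserving operators, so EnsureRequirements can compare it directly against a parent's requiredChildOrdering instead of comparing opaque Ordering[Row] instances.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, IsNotNull, SortOrder}
import org.apache.spark.sql.execution.{Filter, Project, Sort, SparkPlan}
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()
val child: SparkPlan = ???  // any unsorted child plan; placeholder for the sketch

val order     = Seq(SortOrder(key, Ascending))
val sorted    = Sort(order, global = false, child)  // outputOrdering == order
val filtered  = Filter(IsNotNull(key), sorted)      // outputOrdering == order (inherited)
val projected = Project(Seq(key), filtered)         // outputOrdering == order (inherited)

// A parent that requires `order` on this subtree is now satisfied without an extra Sort.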

