Fix issue related to RangePartitioning:
- We now defensively copy rows before computing the partition bounds, which is
  necessary in order to get accurate sampling (see the first sketch below).
- We now pass the actual partitioner into needToCopyObjectsBeforeShuffle(),
  which guards against the fact that RangePartitioner may produce a shuffle
  with fewer than `numPartitions` partitions (see the second sketch below).
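
First sketch: why uncopied mutable records break sampling. Spark SQL reuses a single mutable row per partition iterator, while RangePartitioner samples keys by collecting references to the objects it sees. This is a plain-Scala illustration with no Spark dependency — the object name and the one-element buffer are stand-ins for a reused mutable row, not code from this patch:

    object CopyBeforeSampleSketch extends App {
      // A single reused buffer stands in for Spark SQL's reused mutable rows.
      val reused = Array(0)

      // "Sampling" collects references: every element aliases the same buffer,
      // so the sample reflects only the last value written.
      val aliased = (1 to 3).iterator.map { i => reused(0) = i; reused }.toArray
      println(aliased.map(_(0)).mkString(", ")) // prints: 3, 3, 3

      // Defensively copying each record before it is collected preserves the
      // data, which is what the patch does before handing rows to the sampler.
      val copied = (1 to 3).iterator.map { i => reused(0) = i; reused.clone() }.toArray
      println(copied.map(_(0)).mkString(", ")) // prints: 1, 2, 3
    }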
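
Second sketch: RangePartitioner's partition count is data-dependent, because it is derived from the sampled range bounds rather than taken verbatim from the requested count. The following sketch assumes a local Spark setup; the object name, master URL, and toy data are illustrative:

    import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

    object FewerPartitionsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("range-partitioner-sketch"))
        // Ten partitions requested, but only two distinct keys to sample.
        val pairs = sc.parallelize(Seq((1, "a"), (1, "b"), (2, "c")))
        val part = new RangePartitioner(10, pairs)
        // numPartitions is rangeBounds.length + 1, so it lands well below 10
        // here; this is why the patch passes the partitioner itself, not the
        // requested count, into needToCopyObjectsBeforeShuffle().
        println(part.numPartitions)
        sc.stop()
      }
    }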
JoshRosen committed May 6, 2015
1 parent ad006a4 commit 6a6bfce
Showing 1 changed file with 37 additions and 28 deletions.
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
+import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
@@ -81,21 +81,25 @@ case class Exchange(
    *
    * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
    *
-   * @param numPartitions the number of output partitions produced by the shuffle
+   * @param partitioner the partitioner for the shuffle
    * @param serializer the serializer that will be used to write rows
    * @return true if rows should be copied before being shuffled, false otherwise
    */
   private def needToCopyObjectsBeforeShuffle(
-      numPartitions: Int,
+      partitioner: Partitioner,
       serializer: Serializer): Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, we require it to be
+    // passed instead of directly passing the number of partitions in order to guard against
+    // corner-cases where a partitioner constructed with `numPartitions` partitions may output
+    // fewer partitions (like RangePartitioner, for example).
     if (newOrdering.nonEmpty) {
       // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
       // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
       // However, there are two special cases where we can avoid the copy, described below:
-      if (numPartitions <= bypassMergeThreshold) {
+      if (partitioner.numPartitions <= bypassMergeThreshold) {
         // If the number of output partitions is sufficiently small, then Spark will fall back to
         // the old hash-based shuffle write path which doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
@@ -177,8 +181,9 @@ case class Exchange(
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, valueSchema, numPartitions)
+        val part = new HashPartitioner(numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -190,55 +195,59 @@
             iter.map(r => mutablePair.update(hashExpressions(r), r))
           }
         }
-        val part = new HashPartitioner(numPartitions)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Row, Row](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
         val keySchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(keySchema, null, numPartitions)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
-          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
+        val childRdd = child.execute()
+        val part: Partitioner = {
+          // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
+          // partition bounds. To get accurate samples, we need to copy the mutable keys.
+          val rddForSampling = childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
+            iter.map(row => mutablePair.update(row.copy(), null))
+          }
+          // TODO: RangePartitioner should take an Ordering.
+          implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+          new RangePartitioner(numPartitions, rddForSampling, ascending = true)
+        }
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(part, serializer)) {
+          childRdd.mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
-          child.execute().mapPartitions { iter =>
-            val mutablePair = new MutablePair[Row, Null](null, null)
+          childRdd.mapPartitions { iter =>
+            val mutablePair = new MutablePair[Row, Null]()
             iter.map(row => mutablePair.update(row, null))
           }
         }
 
-        // TODO: RangePartitioner should take an Ordering.
-        implicit val ordering = new RowOrdering(sortingExpressions, child.output)
-
-        val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled =
-          if (newOrdering.nonEmpty) {
-            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
-          } else {
-            new ShuffledRDD[Row, Null, Null](rdd, part)
-          }
+        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        if (newOrdering.nonEmpty) {
+          shuffled.setKeyOrdering(keyOrdering)
+        }
         shuffled.setSerializer(serializer)
         shuffled.map(_._1)
 
       case SinglePartition =>
         val valueSchema = child.output.map(_.dataType).toArray
         val serializer = getSerializer(null, valueSchema, 1)
+        val partitioner = new HashPartitioner(1)
 
-        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions = 1, serializer)) {
+        val rdd = if (needToCopyObjectsBeforeShuffle(partitioner, serializer)) {
          child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
         } else {
           child.execute().mapPartitions { iter =>
             val mutablePair = new MutablePair[Null, Row]()
             iter.map(r => mutablePair.update(null, r))
           }
         }
-        val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         shuffled.setSerializer(serializer)
         shuffled.map(_._2)