[SPARK-2983] [PySpark] improve performance of sortByKey()
1. skip partitionBy() when numPartitions is 1
2. use bisect_left (O(lg(N))) instead of a loop (O(N)) in rangePartitionFunc (see the sketch below)

Author: Davies Liu <[email protected]>

Closes #1898 from davies/sort and squashes the following commits:

0a9608b [Davies Liu] Merge branch 'master' into sort
1cf9565 [Davies Liu] improve performance of sortByKey()
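The second item above is the core of the change: the partition index of a key is found among the sorted range bounds by binary search instead of a linear walk. A minimal standalone sketch of the two lookups (plain Python, not the PySpark code itself; the bounds list, the identity keyfunc, and the test keys are invented for illustration):

import bisect

# hypothetical upper bounds for partitions 0..2; the last partition's bound is implicit
bounds = ['c', 'f', 'k']

def keyfunc(k):
    return k  # identity, matching sortByKey()'s default

def partition_linear(k):
    # old approach: O(N) walk over the bounds
    p = 0
    while p < len(bounds) and keyfunc(k) > bounds[p]:
        p += 1
    return p

def partition_bisect(k):
    # new approach: O(lg(N)) binary search for the same index
    return bisect.bisect_left(bounds, keyfunc(k))

for k in ['a', 'c', 'd', 'k', 'z']:
    assert partition_linear(k) == partition_bisect(k)

Both return the index of the first bound that is >= the key, so for an ascending bounds list they are interchangeable; the saving only becomes visible when numPartitions, and hence the bounds list, is large.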
davies authored and mateiz committed Aug 13, 2014
1 parent c974a71 commit 434bea1
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions python/pyspark/rdd.py
@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log

@@ -574,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         # noqa

         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
@@ -584,42 +587,40 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)

         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
-            samples = self.sample(False, fraction, 1).map(
-                lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
-            # we have numPartitions many parts but one of the them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+        samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+        # we have numPartitions many parts but one of the them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]

         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p

         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                .mapPartitions(mapFunc, preservesPartitioning=True)
-                .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)

     def sortBy(self, keyfunc, ascending=True, numPartitions=None):
         """
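For readers tracing the new code path end to end, the bounds construction and partition assignment can be reproduced outside Spark. A short sketch (plain Python 3 with an invented, already-collected sample list; `//` stands in for the Python 2 integer `/` used in the diff):

import bisect

numPartitions = 3
# pretend these keys came back from self.sample(...).collect()
samples = sorted(['b', 'd', 'f', 'h', 'j', 'l'])

# numPartitions - 1 boundaries, as in the new list comprehension;
# the last partition's upper bound stays implicit
bounds = [samples[len(samples) * (i + 1) // numPartitions]
          for i in range(0, numPartitions - 1)]
print(bounds)  # ['f', 'j']

def rangePartition(k):
    # mirror of rangePartitionFunc for the ascending case
    return bisect.bisect_left(bounds, k)

print([(k, rangePartition(k)) for k in ['a', 'f', 'g', 'k']])
# [('a', 0), ('f', 0), ('g', 1), ('k', 2)]

With ascending=False the real rangePartitionFunc returns numPartitions - 1 - p instead of p.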
