Skip to content

Commit

Permalink
[data] Fix O(n^2) issues in simple_block sort (#19543)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ericl authored Oct 21, 2021
1 parent 45f1ff0 commit 48ecb1f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
23 changes: 19 additions & 4 deletions python/ray/data/impl/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,31 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
# *greater than* the boundary value instead.
col, _ = key[0]
comp_fn = pac.greater if descending else pac.less

# TODO(ekl) this is O(n * len(boundaries)) (one full pass over the column
# per boundary), but in practice it's much faster than the O(n) scan that
# is commented out below; could be optimized.
boundary_indices = [
pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries
]
### Compute the boundary indices in O(n) time via scan. # noqa
# boundary_indices = []
# remaining = boundaries.copy()
# values = table[col]
# for i, x in enumerate(values):
# while remaining and not comp_fn(x, remaining[0]).as_py():
# remaining.pop(0)
# boundary_indices.append(i)
# for _ in remaining:
# boundary_indices.append(len(values))

ret = []
prev_i = 0
for i in boundary_indices:
ret.append(table.slice(prev_i, i - prev_i))
# Slices need to be copied to avoid including the base table
# during serialization.
ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
prev_i = i
ret.append(table.slice(prev_i))
ret.append(_copy_table(table.slice(prev_i)))
return ret

@staticmethod
Expand All @@ -286,8 +302,7 @@ def merge_sorted_blocks(


def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
"""Copy the provided Arrow table.
"""
"""Copy the provided Arrow table."""
import pyarrow as pa

# Copy the table by copying each column and constructing a new table with
Expand Down
15 changes: 12 additions & 3 deletions python/ray/data/impl/simple_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,18 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
key_fn = key if key else lambda x: x
comp_fn = lambda x, b: key_fn(x) > b \
if descending else lambda x, b: key_fn(x) < b # noqa E731
boundary_indices = [
len([1 for x in items if comp_fn(x, b)]) for b in boundaries
]

# Compute the boundary indices in O(n) time via scan.
boundary_indices = []
remaining = boundaries.copy()
for i, x in enumerate(items):
while remaining and not comp_fn(x, remaining[0]):
remaining.pop(0)
boundary_indices.append(i)
for _ in remaining:
boundary_indices.append(len(items))
assert len(boundary_indices) == len(boundaries)

ret = []
prev_i = 0
for i in boundary_indices:
Expand Down

0 comments on commit 48ecb1f

Please sign in to comment.