From 18b42100ffdb19b7f45c8709df9210a9d30d7bb0 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Tue, 19 Oct 2021 22:46:49 -0700
Subject: [PATCH 1/5] fix it

---
 python/ray/data/impl/simple_block.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/ray/data/impl/simple_block.py b/python/ray/data/impl/simple_block.py
index 7816b187d3a4..4ca4bb9f8d00 100644
--- a/python/ray/data/impl/simple_block.py
+++ b/python/ray/data/impl/simple_block.py
@@ -118,9 +118,19 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
         key_fn = key if key else lambda x: x
         comp_fn = lambda x, b: key_fn(x) > b \
             if descending else lambda x, b: key_fn(x) < b  # noqa E731
-        boundary_indices = [
+        boundary_indices = []
+        remaining = boundaries.copy()
+        for i, x in enumerate(items):
+            while remaining and not comp_fn(x, remaining[0]):
+                remaining.pop(0)
+                boundary_indices.append(i)
+        for _ in remaining:
+            boundary_indices.append(len(items))
+        assert len(boundary_indices) == len(boundaries)
+        bx = [
             len([1 for x in items if comp_fn(x, b)]) for b in boundaries
         ]
+        assert bx == boundary_indices
         ret = []
         prev_i = 0
         for i in boundary_indices:

From 411955c4c735ed200f9243e519c52cc731553fc6 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Tue, 19 Oct 2021 22:56:39 -0700
Subject: [PATCH 2/5] fix

---
 python/ray/data/impl/arrow_block.py  | 16 +++++++++++++---
 python/ray/data/impl/simple_block.py |  7 +++----
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index a9d0634930a4..69bb63ce5f82 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -267,9 +267,19 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
         # *greater than* the boundary value instead.
         col, _ = key[0]
         comp_fn = pac.greater if descending else pac.less
-        boundary_indices = [
-            pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries
-        ]
+
+        # Compute the boundary indices in O(n) time via scan.
+        boundary_indices = []
+        remaining = boundaries.copy()
+        values = table[col]
+        for i, x in enumerate(values):
+            while remaining and not comp_fn(x, remaining[0]).as_py():
+                remaining.pop(0)
+                boundary_indices.append(i)
+        for _ in remaining:
+            boundary_indices.append(len(values))
+        assert len(boundary_indices) == len(boundaries)
+
         ret = []
         prev_i = 0
         for i in boundary_indices:
diff --git a/python/ray/data/impl/simple_block.py b/python/ray/data/impl/simple_block.py
index 4ca4bb9f8d00..c5691e2e58b4 100644
--- a/python/ray/data/impl/simple_block.py
+++ b/python/ray/data/impl/simple_block.py
@@ -118,6 +118,8 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
         key_fn = key if key else lambda x: x
         comp_fn = lambda x, b: key_fn(x) > b \
             if descending else lambda x, b: key_fn(x) < b  # noqa E731
+
+        # Compute the boundary indices in O(n) time via scan.
         boundary_indices = []
         remaining = boundaries.copy()
         for i, x in enumerate(items):
@@ -127,10 +129,7 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
         for _ in remaining:
             boundary_indices.append(len(items))
         assert len(boundary_indices) == len(boundaries)
-        bx = [
-            len([1 for x in items if comp_fn(x, b)]) for b in boundaries
-        ]
-        assert bx == boundary_indices
+
         ret = []
         prev_i = 0
         for i in boundary_indices:

From 3dd8ff8c3e665d2d8847d656da91e087d55f4897 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Tue, 19 Oct 2021 23:06:13 -0700
Subject: [PATCH 3/5] update

---
 python/ray/data/impl/arrow_block.py | 47 +++++++++++++++++++++--------
 1 file changed, 34 insertions(+), 13 deletions(-)

diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index 69bb63ce5f82..2b5d91b7cf8e 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -268,24 +268,28 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
         col, _ = key[0]
         comp_fn = pac.greater if descending else pac.less
 
-        # Compute the boundary indices in O(n) time via scan.
-        boundary_indices = []
-        remaining = boundaries.copy()
-        values = table[col]
-        for i, x in enumerate(values):
-            while remaining and not comp_fn(x, remaining[0]).as_py():
-                remaining.pop(0)
-                boundary_indices.append(i)
-        for _ in remaining:
-            boundary_indices.append(len(values))
-        assert len(boundary_indices) == len(boundaries)
+        # TODO(ekl) this is O(n^2) but in practice it's much faster than the
+        # O(n) algorithm, could be optimized.
+        boundary_indices = [
+            pac.sum(comp_fn(table[col], b)).as_py() for b in boundaries
+        ]
+        ### Compute the boundary indices in O(n) time via scan.  # noqa
+        # boundary_indices = []
+        # remaining = boundaries.copy()
+        # values = table[col]
+        # for i, x in enumerate(values):
+        #     while remaining and not comp_fn(x, remaining[0]).as_py():
+        #         remaining.pop(0)
+        #         boundary_indices.append(i)
+        # for _ in remaining:
+        #     boundary_indices.append(len(values))
 
         ret = []
         prev_i = 0
         for i in boundary_indices:
-            ret.append(table.slice(prev_i, i - prev_i))
+            ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
             prev_i = i
-        ret.append(table.slice(prev_i))
+        ret.append(_copy_table(table.slice(prev_i)))
         return ret
 
     @staticmethod
@@ -296,3 +300,20 @@ def merge_sorted_blocks(
         indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
         ret = ret.take(indices)
         return ret, ArrowBlockAccessor(ret).get_metadata(None)
+
+
+def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
+    import pyarrow as pa
+
+    cols = table.columns
+    new_cols = []
+    for col in cols:
+        chunks = []
+        for chunk in col.chunks:
+            if isinstance(chunk, pa.ExtensionArray):
+                new_chunk = type(chunk).from_numpy(chunk.to_numpy())
+            else:
+                new_chunk = pa.array(chunk.to_numpy(), chunk.type)
+            chunks.append(new_chunk)
+        new_cols.append(pa.chunked_array(chunks, col.type))
+    return pa.Table.from_arrays(new_cols, schema=table.schema)

From 84be5528f99a86738f49fa4538a1766c2eef931f Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 20 Oct 2021 10:46:54 -0700
Subject: [PATCH 4/5] update

---
 python/ray/data/impl/arrow_block.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index 2b5d91b7cf8e..34beba416b28 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -303,17 +303,21 @@ def merge_sorted_blocks(
 
 
 def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
+    """Copy the provided Arrow table."""
     import pyarrow as pa
 
+    # Copy the table by copying each column and constructing a new table with
+    # the same schema.
     cols = table.columns
     new_cols = []
     for col in cols:
-        chunks = []
-        for chunk in col.chunks:
-            if isinstance(chunk, pa.ExtensionArray):
-                new_chunk = type(chunk).from_numpy(chunk.to_numpy())
-            else:
-                new_chunk = pa.array(chunk.to_numpy(), chunk.type)
-            chunks.append(new_chunk)
-        new_cols.append(pa.chunked_array(chunks, col.type))
+        if col.num_chunks > 0 and isinstance(col.chunk(0), pa.ExtensionArray):
+            # If an extension array, we copy the underlying storage arrays.
+            chunk = col.chunk(0)
+            arr = type(chunk).from_storage(
+                chunk.type, pa.concat_arrays([c.storage for c in col.chunks]))
+        else:
+            # Otherwise, we copy the top-level chunk arrays.
+            arr = col.combine_chunks()
+        new_cols.append(arr)
     return pa.Table.from_arrays(new_cols, schema=table.schema)

From a2c9101cb53b5be590491c7fb3c7db3ddca768e8 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Wed, 20 Oct 2021 10:47:28 -0700
Subject: [PATCH 5/5] update

---
 python/ray/data/impl/arrow_block.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/ray/data/impl/arrow_block.py b/python/ray/data/impl/arrow_block.py
index 34beba416b28..b7d086862eb1 100644
--- a/python/ray/data/impl/arrow_block.py
+++ b/python/ray/data/impl/arrow_block.py
@@ -287,6 +287,8 @@ def sort_and_partition(self, boundaries: List[T], key: SortKeyT,
         ret = []
         prev_i = 0
         for i in boundary_indices:
+            # Slices need to be copied to avoid including the base table
+            # during serialization.
             ret.append(_copy_table(table.slice(prev_i, i - prev_i)))
             prev_i = i
         ret.append(_copy_table(table.slice(prev_i)))