Skip to content

Commit

Permalink
[Data] combine_chunks before chunking pyarrow.Table block into batches (
Browse files Browse the repository at this point in the history
ray-project#34352)

pyarrow.Table.slice is slow when the table has many chunks which makes batching pyarrow block slow. The fix is combining chunks into a single one to make slice faster with the cost of an extra copy.

Signed-off-by: Jiajun Yao <[email protected]>
  • Loading branch information
jjyao authored and vitsai committed Apr 17, 2023
1 parent b2c750f commit ac6d3c3
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 38 deletions.
19 changes: 1 addition & 18 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,21 +652,4 @@ def gen():

def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
"""Copy the provided Arrow table."""
import pyarrow as pa
from ray.air.util.transform_pyarrow import (
_concatenate_extension_column,
_is_column_extension_type,
)

# Copy the table by copying each column and constructing a new table with
# the same schema.
cols = table.columns
new_cols = []
for col in cols:
if _is_column_extension_type(col):
# Extension arrays don't support concatenation.
arr = _concatenate_extension_column(col)
else:
arr = col.combine_chunks()
new_cols.append(arr)
return pa.Table.from_arrays(new_cols, schema=table.schema)
return transform_pyarrow.combine_chunks(table)
23 changes: 23 additions & 0 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,26 @@ def concat_and_sort(
ret = concat(blocks)
indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
return take_table(ret, indices)


def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table":
"""This is pyarrow.Table.combine_chunks()
with support for extension types.
This will create a new table by combining the chunks the input table has.
"""
from ray.air.util.transform_pyarrow import (
_concatenate_extension_column,
_is_column_extension_type,
)

cols = table.columns
new_cols = []
for col in cols:
if _is_column_extension_type(col):
# Extension arrays don't support concatenation.
arr = _concatenate_extension_column(col)
else:
arr = col.combine_chunks()
new_cols.append(arr)
return pyarrow.Table.from_arrays(new_cols, schema=table.schema)
30 changes: 30 additions & 0 deletions python/ray/data/_internal/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

from ray.data.block import Block, BlockAccessor
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.arrow_block import ArrowBlockAccessor
from ray.data._internal.arrow_ops import transform_pyarrow

# pyarrow.Table.slice is slow when the table has many chunks
# so we combine chunks into a single one to make slice faster
# with the cost of an extra copy.
# See https://github.com/ray-project/ray/issues/31108 for more details.
# TODO(jjyao): remove this once
# https://github.com/apache/arrow/issues/35126 is resolved.
MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS = 2


class BatcherInterface:
Expand Down Expand Up @@ -130,6 +140,15 @@ def next_batch(self) -> Block:
output.add_block(accessor.slice(0, accessor.num_rows(), copy=False))
needed -= accessor.num_rows()
else:
if (
isinstance(accessor, ArrowBlockAccessor)
and block.num_columns > 0
and block.column(0).num_chunks
>= MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS
):
accessor = BlockAccessor.for_block(
transform_pyarrow.combine_chunks(block)
)
# We only need part of the block to fill out a batch.
output.add_block(accessor.slice(0, needed, copy=False))
# Add the rest of the block to the leftovers.
Expand Down Expand Up @@ -296,6 +315,17 @@ def next_batch(self) -> Block:
self._builder.add_block(self._shuffle_buffer)
# Build the new shuffle buffer.
self._shuffle_buffer = self._builder.build()
if (
isinstance(
BlockAccessor.for_block(self._shuffle_buffer), ArrowBlockAccessor
)
and self._shuffle_buffer.num_columns > 0
and self._shuffle_buffer.column(0).num_chunks
>= MIN_NUM_CHUNKS_TO_TRIGGER_COMBINE_CHUNKS
):
self._shuffle_buffer = transform_pyarrow.combine_chunks(
self._shuffle_buffer
)
# Reset the builder.
self._builder = DelegatingBlockBuilder()
# Invalidate the shuffle indices.
Expand Down
41 changes: 40 additions & 1 deletion python/ray/data/tests/test_batcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
import pytest

import pyarrow as pa

from ray.data._internal.batcher import ShufflingBatcher
from ray.data._internal.batcher import ShufflingBatcher, Batcher


def gen_block(num_rows):
Expand Down Expand Up @@ -127,6 +128,44 @@ def next_and_check(
)


def test_batching_pyarrow_table_with_many_chunks():
"""Make sure batching a pyarrow table with many chunks is fast.
See https://github.com/ray-project/ray/issues/31108 for more details.
"""
num_chunks = 5000
batch_size = 1024

batches = []
for _ in range(num_chunks):
batch = {}
for i in range(10):
batch[str(i)] = list(range(batch_size))
batches.append(pa.Table.from_pydict(batch))

block = pa.concat_tables(batches, promote=True)

start = time.perf_counter()
batcher = Batcher(batch_size, ensure_copy=False)
batcher.add(block)
batcher.done_adding()
while batcher.has_any():
batcher.next_batch()
duration = time.perf_counter() - start
assert duration < 10

start = time.perf_counter()
shuffling_batcher = ShufflingBatcher(
batch_size=batch_size, shuffle_buffer_min_size=batch_size
)
shuffling_batcher.add(block)
shuffling_batcher.done_adding()
while shuffling_batcher.has_any():
shuffling_batcher.next_batch()
duration = time.perf_counter() - start
assert duration < 20


if __name__ == "__main__":
import sys

Expand Down
25 changes: 6 additions & 19 deletions release/nightly_tests/dataset/map_batches_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import ray
from ray.data._internal.compute import ActorPoolStrategy, ComputeStrategy
from ray.data.dataset import Dataset
from ray.data.dataset import Dataset, MaterializedDatastream

from benchmark import Benchmark

Expand All @@ -22,9 +22,8 @@ def map_batches(
is_eager_executed: Optional[bool] = False,
) -> Dataset:

assert isinstance(input_ds, MaterializedDatastream)
ds = input_ds
if is_eager_executed:
ds.materialize()

for _ in range(num_calls):
ds = ds.map_batches(
Expand All @@ -34,16 +33,14 @@ def map_batches(
compute=compute,
)
if is_eager_executed:
ds.materialize()
ds = ds.materialize()
return ds


def run_map_batches_benchmark(benchmark: Benchmark):
input_ds = ray.data.read_parquet(
"s3://air-example-data/ursa-labs-taxi-data/by_year/2018/01"
)
lazy_input_ds = input_ds.lazy()
input_ds.materialize()
).materialize()

batch_formats = ["pandas", "numpy"]
batch_sizes = [1024, 2048, 4096, None]
Expand All @@ -52,16 +49,6 @@ def run_map_batches_benchmark(benchmark: Benchmark):
# Test different batch_size of map_batches.
for batch_format in batch_formats:
for batch_size in batch_sizes:
# TODO(chengsu): https://github.com/ray-project/ray/issues/31108
# Investigate why NumPy with batch_size being 1024, took much longer
# to finish.
if (
batch_format == "numpy"
and batch_size is not None
and batch_size == 1024
):
continue

num_calls = 2
test_name = f"map-batches-{batch_format}-{batch_size}-{num_calls}-eager"
benchmark.run(
Expand All @@ -77,7 +64,7 @@ def run_map_batches_benchmark(benchmark: Benchmark):
benchmark.run(
test_name,
map_batches,
input_ds=lazy_input_ds,
input_ds=input_ds,
batch_format=batch_format,
batch_size=batch_size,
num_calls=num_calls,
Expand Down Expand Up @@ -113,7 +100,7 @@ def run_map_batches_benchmark(benchmark: Benchmark):
benchmark.run(
test_name,
map_batches,
input_ds=lazy_input_ds,
input_ds=input_ds,
batch_format=batch_format,
batch_size=batch_size,
compute=compute,
Expand Down

0 comments on commit ac6d3c3

Please sign in to comment.