[Datasets] Allow specifying batch_size when reading Parquet files (#31165)
This PR allows users to specify `batch_size` when reading Parquet files.
c21 authored Dec 21, 2022
1 parent 9b51b01 commit c8443c0
Showing 3 changed files with 14 additions and 3 deletions.
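
In practice, the new argument is forwarded from `ray.data.read_parquet` down to pyarrow. A minimal usage sketch, mirroring the test added in this PR (the path below is illustrative):

```python
import ray

# batch_size is forwarded to pyarrow when reading each Parquet piece;
# the path and sizes here are illustrative.
ray.data.range_tensor(1000, shape=(1000,)).write_parquet("/tmp/data.parquet")
ds = ray.data.read_parquet("/tmp/data.parquet", batch_size=10)
assert ds.count() == 1000
```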
python/ray/data/datasource/parquet_datasource.py (6 changes: 5 additions, 1 deletion)
```diff
@@ -374,13 +374,14 @@ def _read_pieces(

     logger.debug(f"Reading {len(pieces)} parquet pieces")
     use_threads = reader_args.pop("use_threads", False)
+    batch_size = reader_args.pop("batch_size", PARQUET_READER_ROW_BATCH_SIZE)
     for piece in pieces:
         part = _get_partition_keys(piece.partition_expression)
         batches = piece.to_batches(
             use_threads=use_threads,
             columns=columns,
             schema=schema,
-            batch_size=PARQUET_READER_ROW_BATCH_SIZE,
+            batch_size=batch_size,
             **reader_args,
         )
         for batch in batches:
@@ -461,6 +462,9 @@ def _sample_piece(
     batch_size = max(
         min(piece.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1
     )
+    # Use the batch_size calculated above, and ignore the one specified by user if set.
+    # This is to avoid sampling too few or too many rows.
+    reader_args.pop("batch_size", None)
     batches = piece.to_batches(
         columns=columns,
         schema=schema,
```
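
For context, `piece` here is a pyarrow dataset fragment, and `batch_size` bounds the row count of each `RecordBatch` that `to_batches` yields. A standalone pyarrow sketch of that behavior, assuming an on-disk Parquet dataset (paths are illustrative):

```python
import pyarrow as pa
import pyarrow.dataset as pds

# Write a small Parquet dataset, then scan its fragments in bounded batches,
# mirroring what _read_pieces does with the user-supplied batch_size.
pds.write_dataset(
    pa.table({"x": list(range(1000))}), "/tmp/example_ds", format="parquet"
)
dataset = pds.dataset("/tmp/example_ds", format="parquet")
for fragment in dataset.get_fragments():
    for batch in fragment.to_batches(batch_size=100):
        assert batch.num_rows <= 100  # each batch holds at most batch_size rows
```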
python/ray/data/read_api.py (4 changes: 2 additions, 2 deletions)
```diff
@@ -459,7 +459,7 @@ def read_parquet(
         Dataset(num_blocks=..., num_rows=150, schema={sepal.length: double, ...})

     For further arguments you can pass to pyarrow as a keyword argument, see
-    https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html
+    https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_fragment

     Args:
         paths: A single file path or directory, or a list of file paths. Multiple
@@ -479,7 +479,7 @@ def read_parquet(
         meta_provider: File metadata provider. Custom metadata providers may
            be able to resolve file metadata more quickly and/or accurately.
        arrow_parquet_args: Other parquet read options to pass to pyarrow, see
-           https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html
+           https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_fragment

     Returns:
         Dataset holding Arrow records read from the specified paths.
```
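
Since `arrow_parquet_args` are passed through to pyarrow's `Scanner.from_fragment`, other scanner options should travel the same route as the new `batch_size`. A sketch under that assumption (path, column name, and filter are illustrative):

```python
import pyarrow.dataset as pds
import ray

# Scanner.from_fragment keywords such as a row filter can be supplied
# alongside batch_size; the path and column here are illustrative.
ds = ray.data.read_parquet(
    "/tmp/data.parquet",
    batch_size=256,
    filter=pds.field("x") > 10,
)
```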
python/ray/data/tests/test_dataset_parquet.py (7 changes: 7 additions, 0 deletions)
```diff
@@ -906,6 +906,13 @@ def test_parquet_read_empty_file(ray_start_regular_shared, tmp_path):
     pd.testing.assert_frame_equal(ds.to_pandas(), table.to_pandas())


+def test_parquet_reader_batch_size(ray_start_regular_shared, tmp_path):
+    path = os.path.join(tmp_path, "data.parquet")
+    ray.data.range_tensor(1000, shape=(1000,)).write_parquet(path)
+    ds = ray.data.read_parquet(path, batch_size=10)
+    assert ds.count() == 1000
+
+
 if __name__ == "__main__":
     import sys

```
