[Datasets] Allow specifying batch_size when reading Parquet files #31165

Merged: 2 commits merged on Dec 21, 2022
6 changes: 5 additions & 1 deletion python/ray/data/datasource/parquet_datasource.py
@@ -374,13 +374,14 @@ def _read_pieces(

logger.debug(f"Reading {len(pieces)} parquet pieces")
use_threads = reader_args.pop("use_threads", False)
batch_size = reader_args.pop("batch_size", PARQUET_READER_ROW_BATCH_SIZE)
Contributor:
The kwargs passed through read_parquet() are for https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html, which doesn't have a batch_size argument.

Contributor:
The question is whether it would make sense to pass args from read_parquet() to APIs other than read_table().

Contributor Author:
I think it makes sense. We are actually passing kwargs to pyarrow.parquet.ParquetDataset and Scanner.from_fragment. These APIs do not have the same arguments as read_table().
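To make the distinction concrete, here is a small sketch (not from this PR; the local file path is assumed) showing that the fragment/Scanner path accepts batch_size while pyarrow.parquet.read_table() does not:

```python
import pyarrow.dataset as ds
import pyarrow.parquet as pq

path = "example.parquet"  # assumed local file, for illustration only

# The dataset/fragment path (what read_parquet() uses under the hood via
# Scanner.from_fragment) accepts batch_size and yields streamed RecordBatches.
dataset = ds.dataset(path, format="parquet")
fragment = next(iter(dataset.get_fragments()))
num_rows = 0
for batch in fragment.to_batches(batch_size=10_000):
    num_rows += batch.num_rows

# pyarrow.parquet.read_table() has no batch_size argument; it materializes
# the whole table in a single call.
table = pq.read_table(path)
assert table.num_rows == num_rows
```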

Contributor:
That doesn't seem desirable. If Dataset is a distributed Arrow, then it may make sense to pass through args to read_table, the single-node version of the read, but not to other APIs.

@clarkzinzow (Contributor), Dec 20, 2022:
@jianoaix The ray.data.read_parquet() API is more of a distributed analog for pyarrow.parquet.ParquetDataset, where we expose certain features that the underlying ParquetDataset provides (e.g. reading path-based partition columns into the data, supporting zero-read filter pushdown on partition columns, etc.). We do actually have a distributed analog for pyarrow.parquet.read_table(), and that's ray.data.read_parquet_bulk(), which doesn't use pyarrow.parquet.ParquetDataset and instead directly uses pyarrow.parquet.read_table().

For this API, I think that directing users in the docs to pyarrow.dataset.Scanner.from_fragment() for **arrow_parquet_args and to pyarrow.parquet.ParquetDataset for dataset_kwargs would be best, and we should look at turning these passthrough arguments into top-level arguments that we define, with the passthrough being an implementation detail. Going forward, if we continue to build out our own format-agnostic partitioning machinery, we should eventually consider switching to pyarrow.parquet.read_table() if/when we achieve feature parity.
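For readers following along, a hedged usage sketch of the two analogs described above (the paths are placeholders, not from this PR):

```python
import ray

# Distributed analog of pyarrow.parquet.ParquetDataset: partition-aware, and
# forwards extra kwargs such as batch_size to Scanner.from_fragment.
ds_partitioned = ray.data.read_parquet(
    "path/to/partitioned_parquet_dir",  # assumed path
    batch_size=10_000,
)

# Distributed analog of pyarrow.parquet.read_table(): skips ParquetDataset
# metadata resolution and reads an explicit list of files.
ds_bulk = ray.data.read_parquet_bulk(
    ["path/to/file-0.parquet", "path/to/file-1.parquet"],  # assumed paths
)
```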

Contributor:
Good to know. IIUC, the difference between read_table vs. ParquetDataset/Scanner is whether we stream-read a single file. So in read_parquet_bulk(), reading a single file is not streamed, whereas in read_parquet() it is. It looks to me like converging on ParquetDataset/Scanner for streaming a single file is a good option for all of those read APIs.

For this PR itself, since it's leveraging the existing arg passing, looks good to move forward. We can discuss in a follow-up how to make the APIs / arg passing better.

@clarkzinzow (Contributor), Dec 20, 2022:
Sounds good!

Small note: the key difference is actually whether or not we're using Arrow's dataset machinery, which gives us a bunch of partitioning support. The streaming vs. full read was just an implementation detail for getting the performance out of read_parquet_bulk() that Amazon needed, since this was before we fixed the buffering for the streaming Parquet read. We could probably move read_parquet_bulk() to a streaming read with a buffer size set and get the same performance as the full read.
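As a rough illustration of the buffered streaming read mentioned here (a sketch only; the option values and file path are assumed, and this is not necessarily the code path read_parquet_bulk() would use):

```python
import pyarrow.dataset as ds

# Enable buffered streaming of column chunks while scanning a Parquet file.
scan_options = ds.ParquetFragmentScanOptions(
    use_buffered_stream=True,     # stream through a fixed-size read buffer
    buffer_size=4 * 1024 * 1024,  # 4 MiB, arbitrary value for illustration
)
fmt = ds.ParquetFileFormat(default_fragment_scan_options=scan_options)

dataset = ds.dataset("example.parquet", format=fmt)  # assumed path
for batch in dataset.to_batches(batch_size=10_000):
    pass  # process each streamed RecordBatch
```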

for piece in pieces:
part = _get_partition_keys(piece.partition_expression)
batches = piece.to_batches(
use_threads=use_threads,
columns=columns,
schema=schema,
batch_size=PARQUET_READER_ROW_BATCH_SIZE,
batch_size=batch_size,
**reader_args,
)
for batch in batches:
@@ -461,6 +462,9 @@ def _sample_piece(
batch_size = max(
min(piece.metadata.num_rows, PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS), 1
)
# Use the batch_size calculated above, and ignore the one specified by the user, if set.
# This avoids sampling too few or too many rows.
reader_args.pop("batch_size", None)
batches = piece.to_batches(
columns=columns,
schema=schema,
4 changes: 2 additions & 2 deletions python/ray/data/read_api.py
@@ -459,7 +459,7 @@ def read_parquet(
Dataset(num_blocks=..., num_rows=150, schema={sepal.length: double, ...})

For further arguments you can pass to pyarrow as a keyword argument, see
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_fragment

Args:
paths: A single file path or directory, or a list of file paths. Multiple
@@ -479,7 +479,7 @@
meta_provider: File metadata provider. Custom metadata providers may
be able to resolve file metadata more quickly and/or accurately.
arrow_parquet_args: Other parquet read options to pass to pyarrow, see
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html#pyarrow.dataset.Scanner.from_fragment

Returns:
Dataset holding Arrow records read from the specified paths.
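To illustrate the updated docstring reference, a hedged example of passing Scanner.from_fragment-style keyword arguments through read_parquet() (the example dataset path and filter expression are assumptions, not part of this PR):

```python
import pyarrow.dataset as pads
import ray

# Both batch_size and filter are Scanner.from_fragment-style kwargs that
# read_parquet() forwards as **arrow_parquet_args.
ds = ray.data.read_parquet(
    "example://iris.parquet",                   # bundled example dataset
    batch_size=4096,                            # the argument added by this PR
    filter=(pads.field("sepal.length") > 5.0),  # filter pushed to the Arrow scanner
)
print(ds.count())
```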
7 changes: 7 additions & 0 deletions python/ray/data/tests/test_dataset_parquet.py
@@ -906,6 +906,13 @@ def test_parquet_read_empty_file(ray_start_regular_shared, tmp_path):
pd.testing.assert_frame_equal(ds.to_pandas(), table.to_pandas())


def test_parquet_reader_batch_size(ray_start_regular_shared, tmp_path):
path = os.path.join(tmp_path, "data.parquet")
ray.data.range_tensor(1000, shape=(1000,)).write_parquet(path)
ds = ray.data.read_parquet(path, batch_size=10)
assert ds.count() == 1000


if __name__ == "__main__":
import sys
