[Data] Add partitioning parameter to read_parquet #47553

Merged 5 commits on Sep 16, 2024
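
A minimal usage sketch of the new parameter, assuming a hypothetical local path; hive-style partitioning is the default, and partition field types can be pinned explicitly as in the updated test below:

import ray
from ray.data.datasource.partitioning import Partitioning

# Default: hive-style "key=value" directory names are parsed into columns.
ds = ray.data.read_parquet("/tmp/partitioned_table")

# Optionally pin the partition field types explicitly.
ds = ray.data.read_parquet(
    "/tmp/partitioned_table",
    partitioning=Partitioning("hive", field_types={"one": int, "two": str}),
)
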
33 changes: 32 additions & 1 deletion python/ray/data/_internal/datasource/parquet_datasource.py
@@ -35,7 +35,11 @@
from ray.data.datasource.datasource import ReadTask
from ray.data.datasource.file_meta_provider import _handle_read_os_error
from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
from ray.data.datasource.partitioning import PathPartitionFilter
from ray.data.datasource.partitioning import (
Partitioning,
PathPartitionFilter,
PathPartitionParser,
)
from ray.data.datasource.path_util import (
_has_file_extension,
_resolve_paths_and_filesystem,
@@ -164,6 +168,7 @@ def __init__(
schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(),
partition_filter: PathPartitionFilter = None,
partitioning: Optional[Partitioning] = Partitioning("hive"),
shuffle: Union[Literal["files"], None] = None,
include_paths: bool = False,
file_extensions: Optional[List[str]] = None,
@@ -280,6 +285,7 @@ def __init__(
self._schema = schema
self._file_metadata_shuffler = None
self._include_paths = include_paths
self._partitioning = partitioning
if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()

@@ -358,13 +364,15 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
columns,
schema,
include_paths,
partitioning,
) = (
self._block_udf,
self._to_batches_kwargs,
self._default_read_batch_size_rows,
self._columns,
self._schema,
self._include_paths,
self._partitioning,
)
read_tasks.append(
ReadTask(
@@ -376,6 +384,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
schema,
f,
include_paths,
partitioning,
),
meta,
)
@@ -403,6 +412,7 @@ def read_fragments(
schema,
serialized_fragments: List[SerializedFragment],
include_paths: bool,
partitioning: Partitioning,
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa
@@ -421,6 +431,10 @@
use_threads = to_batches_kwargs.pop("use_threads", False)
batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)
for fragment in fragments:
partitions = {}
if partitioning is not None:
parse = PathPartitionParser(partitioning)
partitions = parse(fragment.path)

def get_batch_iterable():
return fragment.to_batches(
@@ -440,6 +454,9 @@ def get_batch_iterable():
table = pa.Table.from_batches([batch], schema=schema)
if include_paths:
table = table.append_column("path", [[fragment.path]] * len(table))
if partitions:
table = _add_partitions_to_table(table, partitions)

# If the table is empty, drop it.
if table.num_rows > 0:
if block_udf is not None:
@@ -633,3 +650,17 @@ def sample_fragments(
sample_bar.close()

return sample_infos


def _add_partitions_to_table(table, partitions):
import pyarrow as pa

for field, value in partitions.items():
column = pa.array([value] * len(table))
field_index = table.schema.get_field_index(field)
if field_index != -1:
table = table.set_column(field_index, field, column)
else:
table = table.append_column(field, column)

return table
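
A small sketch of how the new helper behaves once this change is applied, calling it directly on a toy table (the table contents and partition values here are hypothetical): a column that already exists in the file is overwritten by the partition value, otherwise the value is appended as a new constant column.

import pyarrow as pa

from ray.data._internal.datasource.parquet_datasource import (
    _add_partitions_to_table,
)

table = pa.table({"x": [1, 2, 3]})
# "x" already exists in the file, so its values are replaced by the partition
# value; "year" does not, so it is appended as a new column.
table = _add_partitions_to_table(table, {"x": 7, "year": "2024"})
print(table.column_names)       # ['x', 'year']
print(table["x"].to_pylist())   # [7, 7, 7]
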
2 changes: 2 additions & 0 deletions python/ray/data/read_api.py
@@ -597,6 +597,7 @@ def read_parquet(
tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None,
meta_provider: Optional[ParquetMetadataProvider] = None,
partition_filter: Optional[PathPartitionFilter] = None,
partitioning: Optional[Partitioning] = Partitioning("hive"),
shuffle: Union[Literal["files"], None] = None,
include_paths: bool = False,
file_extensions: Optional[List[str]] = None,
@@ -746,6 +747,7 @@ def read_parquet(
schema=schema,
meta_provider=meta_provider,
partition_filter=partition_filter,
partitioning=partitioning,
shuffle=shuffle,
include_paths=include_paths,
file_extensions=file_extensions,
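
Because read_fragments only parses paths when partitioning is not None, the feature can be switched off from the public API; a minimal sketch under that assumption (the path is hypothetical):

import ray

# Skip Ray Data's hive-style path parsing: partition values encoded only in
# directory names are not added to the rows by this code path.
ds = ray.data.read_parquet("/tmp/partitioned_table", partitioning=None)
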
30 changes: 8 additions & 22 deletions python/ray/data/tests/test_parquet.py
@@ -25,6 +25,7 @@
from ray.data.context import DataContext
from ray.data.datasource import DefaultFileMetadataProvider, ParquetMetadataProvider
from ray.data.datasource.parquet_meta_provider import PARALLELIZE_META_FETCH_THRESHOLD
from ray.data.datasource.partitioning import Partitioning
from ray.data.datasource.path_util import _unwrap_protocol
from ray.data.tests.conftest import * # noqa
from ray.data.tests.mock_http_server import * # noqa
@@ -480,20 +481,8 @@ def test_parquet_read_partitioned(ray_start_regular_shared, fs, data_path):
assert ds.schema() is not None
input_files = ds.input_files()
assert len(input_files) == 2, input_files
assert str(ds) == (
"Dataset(\n"
" num_rows=6,\n"
" schema={two: string, "
"one: dictionary<values=int32, indices=int32, ordered=0>}\n"
")"
), ds
assert repr(ds) == (
"Dataset(\n"
" num_rows=6,\n"
" schema={two: string, "
"one: dictionary<values=int32, indices=int32, ordered=0>}\n"
")"
), ds
assert str(ds) == "Dataset(num_rows=6, schema={two: string, one: int64})", ds
assert repr(ds) == "Dataset(num_rows=6, schema={two: string, one: int64})", ds

# Forces a data read.
values = [[s["one"], s["two"]] for s in ds.take()]
@@ -575,7 +564,7 @@ def test_parquet_read_partitioned_with_columns(ray_start_regular_shared, fs, data_path):
columns=["y", "z"],
filesystem=fs,
)
assert ds.columns() == ["y", "z"]
assert set(ds.columns()) == {"y", "z"}
values = [[s["y"], s["z"]] for s in ds.take()]
assert sorted(values) == [
["a", 0.1],
@@ -653,11 +642,8 @@ def test_parquet_read_partitioned_explicit(ray_start_regular_shared, tmp_path):
use_legacy_dataset=False,
)

schema = pa.schema([("one", pa.int32()), ("two", pa.string())])
partitioning = pa.dataset.partitioning(schema, flavor="hive")
ds = ray.data.read_parquet(
str(tmp_path), dataset_kwargs=dict(partitioning=partitioning)
)
partitioning = Partitioning("hive", field_types={"one": int, "two": str})
ds = ray.data.read_parquet(str(tmp_path), partitioning=partitioning)

# Test metadata-only parquet ops.
assert ds.count() == 6
@@ -667,8 +653,8 @@ def test_parquet_read_partitioned_explicit(ray_start_regular_shared, tmp_path):
assert ds.schema() is not None
input_files = ds.input_files()
assert len(input_files) == 2, input_files
assert str(ds) == "Dataset(num_rows=6, schema={two: string, one: int32})", ds
assert repr(ds) == "Dataset(num_rows=6, schema={two: string, one: int32})", ds
assert str(ds) == "Dataset(num_rows=6, schema={two: string, one: int64})", ds
assert repr(ds) == "Dataset(num_rows=6, schema={two: string, one: int64})", ds

# Forces a data read.
values = [[s["one"], s["two"]] for s in ds.take()]
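
For reference, a sketch of the per-fragment parsing step that backs these assertions, using the same Partitioning/PathPartitionParser pair as read_fragments (the bucket path is hypothetical):

from ray.data.datasource.partitioning import Partitioning, PathPartitionParser

partitioning = Partitioning("hive", field_types={"one": int, "two": str})
parse = PathPartitionParser(partitioning)

# Hive-style "key=value" path segments become a dict of partition values,
# which read_fragments then attaches to each block via _add_partitions_to_table.
partitions = parse("example-bucket/one=1/two=a/data.parquet")
print(partitions)  # e.g. {"one": 1, "two": "a"}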