Skip to content

Commit

Permalink
[Data] Add partitioning parameter to read_parquet (#47553)
Browse files Browse the repository at this point in the history
To extract path partition information with `read_parquet`, you pass a
PyArrow `partitioning` object to `dataset_kwargs`. For example:
```
schema = pa.schema([("one", pa.int32()), ("two", pa.string())])
partitioning = pa.dataset.partitioning(schema, flavor="hive")
ds = ray.data.read_parquet(... dataset_kwargs=dict(partitioning=partitioning))
```

This is problematic for two reasons:
1. It tightly couples the interface with the implementation;
partitioning only works if we use `pyarrow.Dataset` in a specific way in
the implementation.
2. It's inconsistent with all of the other file-based API. All other
APIs use expose a top-level `partitioning` parameter (rather than
`dataset_kwargs`) where you pass a Ray Data `Partitioning` object
(rather than a PyArrow partitioning object).

---------

Signed-off-by: Balaji Veeramani <[email protected]>
  • Loading branch information
bveeramani authored Sep 16, 2024
1 parent f9e8e97 commit 1c80db5
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 38 deletions.
88 changes: 87 additions & 1 deletion python/ray/data/_internal/datasource/parquet_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@
_handle_read_os_error,
)
from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider
from ray.data.datasource.partitioning import PathPartitionFilter
from ray.data.datasource.partitioning import (
PartitionDataType,
Partitioning,
PathPartitionFilter,
PathPartitionParser,
)
from ray.data.datasource.path_util import (
_has_file_extension,
_resolve_paths_and_filesystem,
Expand Down Expand Up @@ -164,6 +169,7 @@ def __init__(
schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
meta_provider: ParquetMetadataProvider = ParquetMetadataProvider(),
partition_filter: PathPartitionFilter = None,
partitioning: Optional[Partitioning] = Partitioning("hive"),
shuffle: Union[Literal["files"], None] = None,
include_paths: bool = False,
file_extensions: Optional[List[str]] = None,
Expand Down Expand Up @@ -214,10 +220,22 @@ def __init__(
if dataset_kwargs is None:
dataset_kwargs = {}

if "partitioning" in dataset_kwargs:
raise ValueError(
"The 'partitioning' parameter isn't supported in 'dataset_kwargs'. "
"Use the top-level 'partitioning' parameter instead."
)

# This datasource manually adds partition data at the Ray Data-level. To avoid
# duplicating the partition data, we disable PyArrow's partitioning.
dataset_kwargs["partitioning"] = None

pq_ds = get_parquet_dataset(paths, filesystem, dataset_kwargs)

if schema is None:
schema = pq_ds.schema
schema = _add_partition_fields_to_schema(partitioning, schema, pq_ds)

if columns:
schema = pa.schema(
[schema.field(column) for column in columns], schema.metadata
Expand Down Expand Up @@ -280,6 +298,7 @@ def __init__(
self._schema = schema
self._file_metadata_shuffler = None
self._include_paths = include_paths
self._partitioning = partitioning
if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()

Expand Down Expand Up @@ -358,13 +377,15 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
columns,
schema,
include_paths,
partitioning,
) = (
self._block_udf,
self._to_batches_kwargs,
self._default_read_batch_size_rows,
self._columns,
self._schema,
self._include_paths,
self._partitioning,
)
read_tasks.append(
ReadTask(
Expand All @@ -376,6 +397,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
schema,
f,
include_paths,
partitioning,
),
meta,
)
Expand Down Expand Up @@ -403,6 +425,7 @@ def read_fragments(
schema,
serialized_fragments: List[SerializedFragment],
include_paths: bool,
partitioning: Partitioning,
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa
Expand All @@ -421,6 +444,18 @@ def read_fragments(
use_threads = to_batches_kwargs.pop("use_threads", False)
batch_size = to_batches_kwargs.pop("batch_size", default_read_batch_size_rows)
for fragment in fragments:
partitions = {}
if partitioning is not None:
parse = PathPartitionParser(partitioning)
partitions = parse(fragment.path)

# Filter out partitions that aren't in the user-specified columns list.
if columns is not None:
partitions = {
field_name: value
for field_name, value in partitions.items()
if field_name in columns
}

def get_batch_iterable():
return fragment.to_batches(
Expand All @@ -440,6 +475,9 @@ def get_batch_iterable():
table = pa.Table.from_batches([batch], schema=schema)
if include_paths:
table = table.append_column("path", [[fragment.path]] * len(table))
if partitions:
table = _add_partitions_to_table(partitions, table)

# If the table is empty, drop it.
if table.num_rows > 0:
if block_udf is not None:
Expand Down Expand Up @@ -633,3 +671,51 @@ def sample_fragments(
sample_bar.close()

return sample_infos


def _add_partitions_to_table(
partitions: Dict[str, PartitionDataType], table: "pyarrow.Table"
) -> "pyarrow.Table":
import pyarrow as pa

for field_name, value in partitions.items():
column = pa.array([value] * len(table))
field_index = table.schema.get_field_index(field_name)
if field_index != -1:
table = table.set_column(field_index, field_name, column)
else:
table = table.append_column(field_name, column)

return table


def _add_partition_fields_to_schema(
partitioning: Partitioning,
schema: "pyarrow.Schema",
parquet_dataset: "pyarrow.dataset.Dataset",
) -> "pyarrow.Schema":
"""Return a new schema with partition fields added.
This function infers the partition fields from the first file path in the dataset.
"""
import pyarrow as pa

# If the dataset is empty, we can't infer the partitioning.
if len(parquet_dataset.fragments) == 0:
return schema

# If the dataset isn't partitioned, we don't need to add any fields.
if partitioning is None:
return schema

first_path = parquet_dataset.fragments[0].path
parse = PathPartitionParser(partitioning)
partitions = parse(first_path)
for field_name in partitions:
if field_name in partitioning.field_types:
field_type = pa.from_numpy_dtype(partitioning.field_types[field_name])
else:
field_type = pa.string()
schema = schema.append(pa.field(field_name, field_type))

return schema
40 changes: 36 additions & 4 deletions python/ray/data/datasource/partitioning.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import posixpath
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
import pyarrow


PartitionDataType = Type[Union[int, float, str, bool]]


@DeveloperAPI
class PartitionStyle(str, Enum):
"""Supported dataset partition styles.
Expand Down Expand Up @@ -82,13 +85,19 @@ class Partitioning:
#: Required when parsing DIRECTORY partitioned paths or generating
#: HIVE partitioned paths.
field_names: Optional[List[str]] = None
#: A dictionary that maps partition key names to their desired data type. If not
#: provided, the data type defaults to string.
field_types: Optional[Dict[str, PartitionDataType]] = None
#: Filesystem that will be used for partition path file I/O.
filesystem: Optional["pyarrow.fs.FileSystem"] = None

def __post_init__(self):
if self.base_dir is None:
self.base_dir = ""

if self.field_types is None:
self.field_types = {}

self._normalized_base_dir = None
self._resolved_filesystem = None

Expand Down Expand Up @@ -165,6 +174,7 @@ def of(
style: PartitionStyle = PartitionStyle.HIVE,
base_dir: Optional[str] = None,
field_names: Optional[List[str]] = None,
field_types: Optional[Dict[str, PartitionDataType]] = None,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
) -> "PathPartitionParser":
"""Creates a path-based partition parser using a flattened argument list.
Expand All @@ -180,12 +190,14 @@ def of(
partition key field names must match the order and length of partition
directories discovered. Partition key field names are not required to
exist in the dataset schema.
field_types: A dictionary that maps partition key names to their desired
data type. If not provided, the data type default to string.
filesystem: Filesystem that will be used for partition path file I/O.
Returns:
The new path-based partition parser.
"""
scheme = Partitioning(style, base_dir, field_names, filesystem)
scheme = Partitioning(style, base_dir, field_names, field_types, filesystem)
return PathPartitionParser(scheme)

def __init__(self, partitioning: Partitioning):
Expand Down Expand Up @@ -226,14 +238,20 @@ def __call__(self, path: str) -> Dict[str, str]:
Args:
path: Input file path to parse.
Returns:
Dictionary mapping directory partition keys to values from the input file
path. Returns an empty dictionary for unpartitioned files.
"""
dir_path = self._dir_path_trim_base(path)
if dir_path is None:
return {}
return self._parser_fn(dir_path)
partitions: Dict[str, str] = self._parser_fn(dir_path)

for field, data_type in self._scheme.field_types.items():
partitions[field] = _cast_value(partitions[field], data_type)

return partitions

@property
def scheme(self) -> Partitioning:
Expand Down Expand Up @@ -317,6 +335,7 @@ def of(
style: PartitionStyle = PartitionStyle.HIVE,
base_dir: Optional[str] = None,
field_names: Optional[List[str]] = None,
field_types: Optional[Dict[str, PartitionDataType]] = None,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
) -> "PathPartitionFilter":
"""Creates a path-based partition filter using a flattened argument list.
Expand Down Expand Up @@ -358,12 +377,14 @@ def do_assert(val, msg):
partition key field names must match the order and length of partition
directories discovered. Partition key field names are not required to
exist in the dataset schema.
field_types: A dictionary that maps partition key names to their desired
data type. If not provided, the data type defaults to string.
filesystem: Filesystem that will be used for partition path file I/O.
Returns:
The new path-based partition filter.
"""
scheme = Partitioning(style, base_dir, field_names, filesystem)
scheme = Partitioning(style, base_dir, field_names, field_types, filesystem)
path_partition_parser = PathPartitionParser(scheme)
return PathPartitionFilter(path_partition_parser, filter_fn)

Expand Down Expand Up @@ -422,3 +443,14 @@ def __call__(self, paths: List[str]) -> List[str]:
def parser(self) -> PathPartitionParser:
"""Returns the path partition parser for this filter."""
return self._parser


def _cast_value(value: str, data_type: PartitionDataType) -> Any:
if data_type is int:
return int(value)
elif data_type is float:
return float(value)
elif data_type is bool:
return value.lower() == "true"
else:
return value
4 changes: 4 additions & 0 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ def read_parquet(
tensor_column_schema: Optional[Dict[str, Tuple[np.dtype, Tuple[int, ...]]]] = None,
meta_provider: Optional[ParquetMetadataProvider] = None,
partition_filter: Optional[PathPartitionFilter] = None,
partitioning: Optional[Partitioning] = Partitioning("hive"),
shuffle: Union[Literal["files"], None] = None,
include_paths: bool = False,
file_extensions: Optional[List[str]] = None,
Expand Down Expand Up @@ -703,6 +704,8 @@ def read_parquet(
partition_filter: A
:class:`~ray.data.datasource.partitioning.PathPartitionFilter`. Use
with a custom callback to read only selected partitions of a dataset.
partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
that describes how paths are organized. Defaults to HIVE partitioning.
shuffle: If setting to "files", randomly shuffle input files order before read.
Defaults to not shuffle with ``None``.
arrow_parquet_args: Other parquet read options to pass to PyArrow. For the full
Expand Down Expand Up @@ -747,6 +750,7 @@ def read_parquet(
schema=schema,
meta_provider=meta_provider,
partition_filter=partition_filter,
partitioning=partitioning,
shuffle=shuffle,
include_paths=include_paths,
file_extensions=file_extensions,
Expand Down
Loading

0 comments on commit 1c80db5

Please sign in to comment.