[Datasets] Force local metadata resolution when unserializable Partitioning object provided. #22477
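Summary: `read_datasource()` normally resolves datasource metadata inside a remote task (so Ray client users don't resolve it from the client machine), but a `pyarrow.dataset.Partitioning` object passed via `dataset_kwargs` may not survive serialization into that task. This PR detects such an argument and falls back to local metadata resolution. Below is a minimal sketch of the failure mode being guarded against, assuming a pyarrow version on which `Partitioning` lacks pickle support; the `cloudpickle` call merely stands in for Ray's task-argument serialization.

```python
# Illustrative only: why an explicit Partitioning forces the local path.
import pyarrow as pa
import pyarrow.dataset as pa_ds
from ray import cloudpickle

schema = pa.schema([("one", pa.int32()), ("two", pa.string())])
partitioning = pa_ds.partitioning(schema, flavor="hive")

try:
    # Shipping this object to a remote prepare_read task requires pickling it,
    # which may fail on pyarrow versions without Partitioning pickle support.
    cloudpickle.dumps(partitioning)
except Exception as exc:
    print(f"Partitioning is not serializable: {exc}")
```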

Merged
8 changes: 4 additions & 4 deletions python/ray/data/dataset.py
@@ -60,8 +60,8 @@
     WriteResult,
 )
 from ray.data.datasource.file_based_datasource import (
-    _wrap_s3_filesystem_workaround,
-    _unwrap_s3_filesystem_workaround,
+    _wrap_arrow_serialization_workaround,
+    _unwrap_arrow_serialization_workaround,
 )
 from ray.data.row import TableRow
 from ray.data.aggregate import AggregateFn, Sum, Max, Min, Mean, Std
@@ -1868,7 +1868,7 @@ def write_datasource(self, datasource: Datasource[T], **write_args) -> None:
                     ctx,
                     blocks,
                     metadata,
-                    _wrap_s3_filesystem_workaround(write_args),
+                    _wrap_arrow_serialization_workaround(write_args),
                 )
             )

@@ -2982,6 +2982,6 @@ def _do_write(
     meta: List[BlockMetadata],
     write_args: dict,
 ) -> List[ObjectRef[WriteResult]]:
-    write_args = _unwrap_s3_filesystem_workaround(write_args)
+    write_args = _unwrap_arrow_serialization_workaround(write_args)
     DatasetContext._set_current(ctx)
     return ds.do_write(blocks, meta, **write_args)
5 changes: 3 additions & 2 deletions python/ray/data/datasource/file_based_datasource.py
@@ -511,13 +511,14 @@ def __reduce__(self):
         return _S3FileSystemWrapper._reconstruct, self._fs.__reduce__()


-def _wrap_s3_filesystem_workaround(kwargs: dict) -> dict:
+def _wrap_arrow_serialization_workaround(kwargs: dict) -> dict:
     if "filesystem" in kwargs:
         kwargs["filesystem"] = _wrap_s3_serialization_workaround(kwargs["filesystem"])
+
     return kwargs


-def _unwrap_s3_filesystem_workaround(kwargs: dict) -> dict:
+def _unwrap_arrow_serialization_workaround(kwargs: dict) -> dict:
     if isinstance(kwargs.get("filesystem"), _S3FileSystemWrapper):
         kwargs["filesystem"] = kwargs["filesystem"].unwrap()
     return kwargs
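For orientation, a tiny hypothetical round trip showing the contract of the renamed helpers: both are no-ops unless a `filesystem` entry is involved, so call sites can apply them unconditionally on either side of a task boundary. The `read_args` dict below is made up; only the two helpers come from this file.

```python
# Hypothetical kwargs; only the two helpers above are from this module.
read_args = {"paths": ["/tmp/table"], "columns": ["one", "two"]}

wrapped = _wrap_arrow_serialization_workaround(dict(read_args))
restored = _unwrap_arrow_serialization_workaround(wrapped)

# Without a "filesystem" entry the round trip is the identity; with a
# pyarrow S3 filesystem, the wrapper defers its serialization instead.
assert restored == read_args
```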
19 changes: 18 additions & 1 deletion python/ray/data/impl/util.py
@@ -1,6 +1,7 @@
 import itertools
 import logging
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Union
+from types import ModuleType

 from ray.remote_function import DEFAULT_REMOTE_FUNCTION_CPUS
 import ray.ray_constants as ray_constants
@@ -11,6 +12,22 @@
 _VERSION_VALIDATED = False


+LazyModule = Union[None, bool, ModuleType]
+_pyarrow: LazyModule = None
+
+
+def _lazy_import_pyarrow() -> LazyModule:
+    global _pyarrow
+    if _pyarrow is None:
+        try:
+            import pyarrow as _pyarrow
+        except ModuleNotFoundError:
+            # If module is not found, set _pyarrow to False so we won't
+            # keep trying to import it on every _lazy_import_pyarrow() call.
+            _pyarrow = False
+    return _pyarrow
+
+
 def _check_pyarrow_version():
     global _VERSION_VALIDATED
     if not _VERSION_VALIDATED:
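A short usage sketch of the new helper (the calling function below is illustrative, not part of the diff): `_lazy_import_pyarrow()` returns the `pyarrow` module when available and caches `False` after a failed import, so repeated calls stay cheap.

```python
# Illustrative caller; only _lazy_import_pyarrow comes from this module.
from ray.data.impl.util import _lazy_import_pyarrow


def _is_pyarrow_partitioning(obj) -> bool:
    pa = _lazy_import_pyarrow()
    if not pa:
        # pyarrow isn't installed; the failed import was cached as False,
        # so this branch doesn't re-attempt the import on every call.
        return False
    return isinstance(obj, pa.dataset.Partitioning)
```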
22 changes: 15 additions & 7 deletions python/ray/data/read_api.py
@@ -48,8 +48,8 @@
     ReadTask,
 )
 from ray.data.datasource.file_based_datasource import (
-    _wrap_s3_filesystem_workaround,
-    _unwrap_s3_filesystem_workaround,
+    _wrap_arrow_serialization_workaround,
+    _unwrap_arrow_serialization_workaround,
 )
 from ray.data.impl.delegating_block_builder import DelegatingBlockBuilder
 from ray.data.impl.arrow_block import ArrowRow
@@ -58,7 +58,7 @@
 from ray.data.impl.plan import ExecutionPlan
 from ray.data.impl.remote_fn import cached_remote_fn
 from ray.data.impl.stats import DatasetStats, get_or_create_stats_actor
-from ray.data.impl.util import _get_spread_resources_iter
+from ray.data.impl.util import _get_spread_resources_iter, _lazy_import_pyarrow

 T = TypeVar("T")

@@ -205,9 +205,14 @@ def read_datasource(
     Returns:
         Dataset holding the data read from the datasource.
     """
-
     # TODO(ekl) remove this feature flag.
-    if "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ:
+    force_local = "RAY_DATASET_FORCE_LOCAL_METADATA" in os.environ
+    pa = _lazy_import_pyarrow()
+    if pa:
+        partitioning = read_args.get("dataset_kwargs", {}).get("partitioning", None)
+        force_local = force_local or isinstance(partitioning, pa.dataset.Partitioning)
+
+    if force_local:
         read_tasks = datasource.prepare_read(parallelism, **read_args)
     else:
         # Prepare read in a remote task so that in Ray client mode, we aren't
@@ -218,7 +223,10 @@
         )
         read_tasks = ray.get(
             prepare_read.remote(
-                datasource, ctx, parallelism, _wrap_s3_filesystem_workaround(read_args)
+                datasource,
+                ctx,
+                parallelism,
+                _wrap_arrow_serialization_workaround(read_args),
             )
         )

@@ -873,6 +881,6 @@ def _get_metadata(table: Union["pyarrow.Table", "pandas.DataFrame"]) -> BlockMetadata:
 def _prepare_read(
     ds: Datasource, ctx: DatasetContext, parallelism: int, kwargs: dict
 ) -> List[ReadTask]:
-    kwargs = _unwrap_s3_filesystem_workaround(kwargs)
+    kwargs = _unwrap_arrow_serialization_workaround(kwargs)
     DatasetContext._set_current(ctx)
     return ds.prepare_read(parallelism, **kwargs)
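Taken together, `read_datasource()` now resolves metadata locally in two cases: when the pre-existing `RAY_DATASET_FORCE_LOCAL_METADATA` feature flag is set, or when an explicit pyarrow `Partitioning` is passed through `dataset_kwargs`. A brief usage sketch follows; the path is a placeholder, and the environment flag is the pre-existing escape hatch rather than something this PR adds.

```python
import os

import pyarrow as pa
import pyarrow.dataset as pa_ds
import ray

# Case 1 (pre-existing feature flag): force local metadata resolution.
os.environ["RAY_DATASET_FORCE_LOCAL_METADATA"] = "1"

# Case 2 (this PR): an explicit Partitioning in dataset_kwargs triggers the
# local path automatically, since the object can't be shipped to the remote
# _prepare_read task.
schema = pa.schema([("one", pa.int32()), ("two", pa.string())])
ds = ray.data.read_parquet(
    "/path/to/partitioned/table",  # placeholder path
    dataset_kwargs=dict(partitioning=pa_ds.partitioning(schema, flavor="hive")),
)
```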
49 changes: 49 additions & 0 deletions python/ray/data/tests/test_dataset_formats.py
@@ -383,6 +383,55 @@ def test_parquet_read_partitioned_with_filter(ray_start_regular_shared, tmp_path):
     assert sorted(values) == [[1, "a"], [1, "a"]]


+def test_parquet_read_partitioned_explicit(ray_start_regular_shared, tmp_path):
+    df = pd.DataFrame(
+        {"one": [1, 1, 1, 3, 3, 3], "two": ["a", "b", "c", "e", "f", "g"]}
+    )
+    table = pa.Table.from_pandas(df)
+    pq.write_to_dataset(
+        table,
+        root_path=str(tmp_path),
+        partition_cols=["one"],
+        use_legacy_dataset=False,
+    )
+
+    schema = pa.schema([("one", pa.int32()), ("two", pa.string())])
+    partitioning = pa.dataset.partitioning(schema, flavor="hive")
+
+    ds = ray.data.read_parquet(
+        str(tmp_path), dataset_kwargs=dict(partitioning=partitioning)
+    )
+
+    # Test metadata-only parquet ops.
+    assert ds._plan.execute()._num_computed() == 1
+    assert ds.count() == 6
+    assert ds.size_bytes() > 0
+    assert ds.schema() is not None
+    input_files = ds.input_files()
+    assert len(input_files) == 2, input_files
+    assert (
+        str(ds) == "Dataset(num_blocks=2, num_rows=6, "
+        "schema={two: string, one: int32})"
+    ), ds
+    assert (
+        repr(ds) == "Dataset(num_blocks=2, num_rows=6, "
+        "schema={two: string, one: int32})"
+    ), ds
+    assert ds._plan.execute()._num_computed() == 1
+
+    # Forces a data read.
+    values = [[s["one"], s["two"]] for s in ds.take()]
+    assert ds._plan.execute()._num_computed() == 2
+    assert sorted(values) == [
+        [1, "a"],
+        [1, "b"],
+        [1, "c"],
+        [3, "e"],
+        [3, "f"],
+        [3, "g"],
+    ]
+
+
 def test_parquet_read_with_udf(ray_start_regular_shared, tmp_path):
     one_data = list(range(6))
     df = pd.DataFrame({"one": one_data, "two": 2 * ["a"] + 2 * ["b"] + 2 * ["c"]})