[Datasets] Customized serializer for Arrow JSON ParseOptions in read_json #27911

Merged · 1 commit · Aug 18, 2022
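In short: pyarrow.json.ParseOptions isn't picklable on pyarrow < 8.0.0, so this PR registers a custom serializer for it (mirroring the existing ReadOptions workaround) so that read_json can ship the options to its read tasks. A minimal usage sketch, adapted from the new test below; the file path and its contents are hypothetical:

    import pyarrow as pa
    import pyarrow.json as pajson
    import ray

    # Read newline-delimited JSON records such as {"one": 1, "two": "a"},
    # keeping only the declared "two" column and ignoring other fields.
    ds = ray.data.read_json(
        "example.json",  # hypothetical path
        parse_options=pajson.ParseOptions(
            explicit_schema=pa.schema([("two", pa.string())]),
            unexpected_field_behavior="ignore",
        ),
    )
    print(ds.schema())
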
5 changes: 4 additions & 1 deletion python/ray/data/__init__.py
@@ -1,5 +1,6 @@
 import ray
 from ray.data._internal.arrow_serialization import (
+    _register_arrow_json_parseoptions_serializer,
     _register_arrow_json_readoptions_serializer,
 )
 from ray.data._internal.compute import ActorPoolStrategy
@@ -35,9 +36,11 @@
     read_text,
 )
 
-# Register custom Arrow JSON ReadOptions serializer after worker has initialized.
+# Register custom Arrow JSON ReadOptions and ParseOptions serializers after worker
+# has initialized.
 if ray.is_initialized():
     _register_arrow_json_readoptions_serializer()
+    _register_arrow_json_parseoptions_serializer()
 else:
     pass
     # ray._internal.worker._post_init_hooks.append(_register_arrow_json_readoptions_serializer)
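The else branch remains a no-op: if Ray isn't initialized when ray.data is imported, registration is skipped (the commented-out line sketches a possible post-init hook). A minimal sketch of registering manually in that case, assuming these private helpers keep their current names:

    import ray
    from ray.data._internal.arrow_serialization import (
        _register_arrow_json_parseoptions_serializer,
        _register_arrow_json_readoptions_serializer,
    )

    # Per the comment above, registration should happen after the worker
    # has initialized.
    ray.init()
    _register_arrow_json_readoptions_serializer()
    _register_arrow_json_parseoptions_serializer()
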
38 changes: 37 additions & 1 deletion python/ray/data/_internal/arrow_serialization.py
@@ -1,12 +1,16 @@
 import os
 
+RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION = (
+    "RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION"
+)
 
 
 def _register_arrow_json_readoptions_serializer():
     import ray
 
     if (
         os.environ.get(
-            "RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION",
+            RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
             "0",
         )
         == "1"
@@ -27,3 +31,35 @@ def _register_arrow_json_readoptions_serializer():
             serializer=lambda opts: (opts.use_threads, opts.block_size),
             deserializer=lambda args: pajson.ReadOptions(*args),
         )
+
+
+def _register_arrow_json_parseoptions_serializer():
+    import ray
+
+    if (
+        os.environ.get(
+            RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
+            "0",
+        )
+        == "1"
+    ):
+        import logging
+
+        logger = logging.getLogger(__name__)
+        logger.info("Disabling custom Arrow JSON ParseOptions serialization.")
+        return
+
+    try:
+        import pyarrow.json as pajson
+    except ModuleNotFoundError:
+        return
+
+    ray.util.register_serializer(
+        pajson.ParseOptions,
+        serializer=lambda opts: (
+            opts.explicit_schema,
+            opts.newlines_in_values,
+            opts.unexpected_field_behavior,
+        ),
+        deserializer=lambda args: pajson.ParseOptions(*args),
+    )
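The serializer/deserializer pair is a hand-written reducer: decompose the options into a plain tuple of constructor arguments and rebuild an equivalent object on the receiving side. A standalone round-trip sketch of the same pattern, runnable wherever pyarrow.json is available:

    import pyarrow.json as pajson

    opts = pajson.ParseOptions(unexpected_field_behavior="ignore")
    # Decompose into plain, picklable state...
    state = (
        opts.explicit_schema,
        opts.newlines_in_values,
        opts.unexpected_field_behavior,
    )
    # ...then rebuild an equivalent ParseOptions from that state.
    restored = pajson.ParseOptions(*state)
    assert restored.unexpected_field_behavior == "ignore"
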
15 changes: 11 additions & 4 deletions python/ray/data/datasource/file_based_datasource.py
@@ -17,6 +17,7 @@

 from ray.data._internal.arrow_block import ArrowRow
 from ray.data._internal.arrow_serialization import (
+    _register_arrow_json_parseoptions_serializer,
     _register_arrow_json_readoptions_serializer,
 )
 from ray.data._internal.block_list import BlockMetadata
@@ -379,11 +380,14 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
         read_stream = self._delegate._read_stream
         filesystem = _wrap_s3_serialization_workaround(self._filesystem)
         read_options = reader_args.get("read_options")
-        if read_options is not None:
+        parse_options = reader_args.get("parse_options")
+        if read_options is not None or parse_options is not None:
             import pyarrow.json as pajson
 
             if isinstance(read_options, pajson.ReadOptions):
                 _register_arrow_json_readoptions_serializer()
+            if isinstance(parse_options, pajson.ParseOptions):
+                _register_arrow_json_parseoptions_serializer()
 
         if open_stream_args is None:
             open_stream_args = {}
@@ -695,13 +699,16 @@ def _wrap_and_register_arrow_serialization_workaround(kwargs: dict) -> dict:
     # TODO(Clark): Remove this serialization workaround once Datasets only supports
     # pyarrow >= 8.0.0.
     read_options = kwargs.get("read_options")
-    if read_options is not None:
+    parse_options = kwargs.get("parse_options")
+    if read_options is not None or parse_options is not None:
         import pyarrow.json as pajson
 
-        # Register a custom serializer instead of wrapping the options, since a
-        # custom reducer will suffice.
         if isinstance(read_options, pajson.ReadOptions):
+            # Register a custom serializer instead of wrapping the options, since a
+            # custom reducer will suffice.
            _register_arrow_json_readoptions_serializer()
+        if isinstance(parse_options, pajson.ParseOptions):
+            _register_arrow_json_parseoptions_serializer()
 
     return kwargs

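Both call sites honor the same opt-out: the registration helpers return early when the environment variable is set to "1", e.g. on pyarrow >= 8.0.0 where these options pickle natively. A sketch:

    import os

    # Skip the custom ReadOptions/ParseOptions serializers; set this before
    # the workers that would deserialize the options are started.
    os.environ["RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION"] = "1"
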
43 changes: 43 additions & 0 deletions python/ray/data/tests/test_dataset_formats.py
@@ -1981,6 +1981,49 @@ def test_json_read_with_read_options(
assert "{one: int64, two: string}" in str(ds), ds


@pytest.mark.parametrize(
"fs,data_path,endpoint_url",
[
(None, lazy_fixture("local_path"), None),
(lazy_fixture("local_fs"), lazy_fixture("local_path"), None),
(lazy_fixture("s3_fs"), lazy_fixture("s3_path"), lazy_fixture("s3_server")),
],
)
def test_json_read_with_parse_options(
ray_start_regular_shared,
fs,
data_path,
endpoint_url,
):
# Arrow's JSON ParseOptions isn't serializable in pyarrow < 8.0.0, so this test
# covers our custom ParseOptions serializer, similar to ReadOptions in above test.
# TODO(chengsu): Remove this test and our custom serializer once we require
# pyarrow >= 8.0.0.
if endpoint_url is None:
storage_options = {}
else:
storage_options = dict(client_kwargs=dict(endpoint_url=endpoint_url))

df1 = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})
path1 = os.path.join(data_path, "test1.json")
df1.to_json(path1, orient="records", lines=True, storage_options=storage_options)
ds = ray.data.read_json(
path1,
filesystem=fs,
parse_options=pajson.ParseOptions(
explicit_schema=pa.schema([("two", pa.string())]),
unexpected_field_behavior="ignore",
),
)
dsdf = ds.to_pandas()
assert len(dsdf.columns) == 1
assert (df1["two"]).equals(dsdf["two"])
# Test metadata ops.
assert ds.count() == 3
assert ds.input_files() == [_unwrap_protocol(path1)]
assert "{two: string}" in str(ds), ds


 @pytest.mark.parametrize(
     "fs,data_path,endpoint_url",
     [
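To exercise just the new test locally, something like pytest python/ray/data/tests/test_dataset_formats.py -k test_json_read_with_parse_options should work, assuming pytest and the pytest-lazy-fixture plugin (which provides lazy_fixture) are installed.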