From e03bfb56295fb05dfdd44928aa482b2249c5fec2 Mon Sep 17 00:00:00 2001 From: Matthew Deng Date: Fri, 9 Aug 2024 17:03:23 -0700 Subject: [PATCH 1/4] [data] add validation for shuffle arg Signed-off-by: Matthew Deng --- .../datasource/parquet_datasource.py | 5 +++- .../data/datasource/file_based_datasource.py | 9 +++++++ .../data/tests/test_file_based_datasource.py | 24 +++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index da050627cec2..a73b78ff7bcf 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -33,6 +33,7 @@ get_generic_metadata_provider, ) from ray.data.datasource.datasource import ReadTask +from ray.data.datasource.file_based_datasource import _validate_shuffle_arg from ray.data.datasource.file_meta_provider import _handle_read_os_error from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider from ray.data.datasource.partitioning import PathPartitionFilter @@ -278,8 +279,10 @@ def __init__( self._to_batches_kwargs = to_batch_kwargs self._columns = columns self._schema = schema - self._file_metadata_shuffler = None self._include_paths = include_paths + + _validate_shuffle_arg(shuffle) + self._file_metadata_shuffler = None if shuffle == "files": self._file_metadata_shuffler = np.random.default_rng() diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index 8e539aaa577d..fc92ca458067 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -150,6 +150,7 @@ def __init__( "'file_extensions' field is set properly." ) + _validate_shuffle_arg(shuffle) self._file_metadata_shuffler = None if shuffle == "files": self._file_metadata_shuffler = np.random.default_rng() @@ -519,3 +520,11 @@ def _open_file_with_retry( max_attempts=OPEN_FILE_MAX_ATTEMPTS, max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS, ) + + +def _validate_shuffle_arg(shuffle: Optional[str]) -> None: + if shuffle not in [None, "files"]: + raise ValueError( + f"Invalid value for 'shuffle': {shuffle}. " + "Valid values are None, 'files'." + ) diff --git a/python/ray/data/tests/test_file_based_datasource.py b/python/ray/data/tests/test_file_based_datasource.py index 621150e148c6..77abea99bcab 100644 --- a/python/ray/data/tests/test_file_based_datasource.py +++ b/python/ray/data/tests/test_file_based_datasource.py @@ -118,6 +118,30 @@ def test_windows_path(): assert _is_local_windows_path("c:\\some\\where/mixed") +@pytest.mark.parametrize( + "shuffle, valid", + [ + (None, True), + ("files", True), + (True, False), + (False, False), + ("file", False), + ], +) +def test_shuffle_arg(ray_start_regular_shared, tmp_path, shuffle, valid): + + path = os.path.join(tmp_path, "test.txt") + with open(path, "w"): + pass + + if valid: + FileBasedDatasource(path, shuffle=shuffle) + + else: + with pytest.raises(ValueError): + FileBasedDatasource(path, shuffle=shuffle) + + if __name__ == "__main__": import sys From 994cccb892a586964055a0811a178ee7f50336c4 Mon Sep 17 00:00:00 2001 From: Matthew Deng Date: Mon, 12 Aug 2024 14:31:44 -0700 Subject: [PATCH 2/4] address comments Signed-off-by: Matthew Deng --- .../datasource/parquet_datasource.py | 4 +-- python/ray/data/read_api.py | 10 ++++++ .../data/tests/test_file_based_datasource.py | 34 ++++++++++++------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index a73b78ff7bcf..8d7aca6c6a2e 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -33,7 +33,6 @@ get_generic_metadata_provider, ) from ray.data.datasource.datasource import ReadTask -from ray.data.datasource.file_based_datasource import _validate_shuffle_arg from ray.data.datasource.file_meta_provider import _handle_read_os_error from ray.data.datasource.parquet_meta_provider import ParquetMetadataProvider from ray.data.datasource.partitioning import PathPartitionFilter @@ -279,10 +278,9 @@ def __init__( self._to_batches_kwargs = to_batch_kwargs self._columns = columns self._schema = schema + self._file_metadata_shuffler = None self._include_paths = include_paths - _validate_shuffle_arg(shuffle) - self._file_metadata_shuffler = None if shuffle == "files": self._file_metadata_shuffler = np.random.default_rng() diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index 082270bab596..388ef36edfc6 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -762,6 +762,8 @@ def read_parquet( **arrow_parquet_args, ) + _validate_shuffle_arg + dataset_kwargs = arrow_parquet_args.pop("dataset_kwargs", None) _block_udf = arrow_parquet_args.pop("_block_udf", None) schema = arrow_parquet_args.pop("schema", None) @@ -3157,3 +3159,11 @@ def _get_num_output_blocks( elif override_num_blocks is not None: parallelism = override_num_blocks return parallelism + + +def _validate_shuffle_arg(shuffle: Optional[str]) -> None: + if shuffle not in [None, "files"]: + raise ValueError( + f"Invalid value for 'shuffle': {shuffle}. " + "Valid values are None, 'files'." + ) diff --git a/python/ray/data/tests/test_file_based_datasource.py b/python/ray/data/tests/test_file_based_datasource.py index 77abea99bcab..e3e6dcbf7538 100644 --- a/python/ray/data/tests/test_file_based_datasource.py +++ b/python/ray/data/tests/test_file_based_datasource.py @@ -119,27 +119,35 @@ def test_windows_path(): @pytest.mark.parametrize( - "shuffle, valid", - [ - (None, True), - ("files", True), - (True, False), - (False, False), - ("file", False), - ], + "shuffle", + [True, False, "file"], ) -def test_shuffle_arg(ray_start_regular_shared, tmp_path, shuffle, valid): +def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, tmp_path, shuffle): path = os.path.join(tmp_path, "test.txt") with open(path, "w"): pass - if valid: + with pytest.raises(ValueError): FileBasedDatasource(path, shuffle=shuffle) - else: - with pytest.raises(ValueError): - FileBasedDatasource(path, shuffle=shuffle) + +@pytest.mark.parametrize( + "shuffle", + [ + None, + "files", + ], +) +def test_valid_shuffle_arg_does_not_raise_error( + ray_start_regular_shared, tmp_path, shuffle +): + + path = os.path.join(tmp_path, "test.txt") + with open(path, "w"): + pass + + FileBasedDatasource(path, shuffle=shuffle) if __name__ == "__main__": From 0b33e1076175dc90d74dab881d21e77b9a7a9012 Mon Sep 17 00:00:00 2001 From: Matthew Deng Date: Mon, 12 Aug 2024 15:18:30 -0700 Subject: [PATCH 3/4] update tests Signed-off-by: Matthew Deng --- python/ray/data/read_api.py | 4 +-- .../data/tests/test_file_based_datasource.py | 33 ++++--------------- python/ray/data/tests/test_parquet.py | 12 +++++++ 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index 388ef36edfc6..869bf8865cb0 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -755,6 +755,8 @@ def read_parquet( :class:`~ray.data.Dataset` producing records read from the specified parquet files. """ + _validate_shuffle_arg(shuffle) + if meta_provider is None: meta_provider = get_parquet_metadata_provider(override_num_blocks) arrow_parquet_args = _resolve_parquet_args( @@ -762,8 +764,6 @@ def read_parquet( **arrow_parquet_args, ) - _validate_shuffle_arg - dataset_kwargs = arrow_parquet_args.pop("dataset_kwargs", None) _block_udf = arrow_parquet_args.pop("_block_udf", None) schema = arrow_parquet_args.pop("schema", None) diff --git a/python/ray/data/tests/test_file_based_datasource.py b/python/ray/data/tests/test_file_based_datasource.py index e3e6dcbf7538..ca7c4ff3325a 100644 --- a/python/ray/data/tests/test_file_based_datasource.py +++ b/python/ray/data/tests/test_file_based_datasource.py @@ -118,36 +118,15 @@ def test_windows_path(): assert _is_local_windows_path("c:\\some\\where/mixed") -@pytest.mark.parametrize( - "shuffle", - [True, False, "file"], -) -def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, tmp_path, shuffle): - - path = os.path.join(tmp_path, "test.txt") - with open(path, "w"): - pass - +@pytest.mark.parametrize("shuffle", [True, False, "file"]) +def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, shuffle): with pytest.raises(ValueError): - FileBasedDatasource(path, shuffle=shuffle) + FileBasedDatasource("example://iris.csv", shuffle=shuffle) -@pytest.mark.parametrize( - "shuffle", - [ - None, - "files", - ], -) -def test_valid_shuffle_arg_does_not_raise_error( - ray_start_regular_shared, tmp_path, shuffle -): - - path = os.path.join(tmp_path, "test.txt") - with open(path, "w"): - pass - - FileBasedDatasource(path, shuffle=shuffle) +@pytest.mark.parametrize("shuffle", [None, "files"]) +def test_valid_shuffle_arg_does_not_raise_error(ray_start_regular_shared, shuffle): + FileBasedDatasource("example://iris.csv", shuffle=shuffle) if __name__ == "__main__": diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py index bec6e1699a97..bbfd6ba69bfc 100644 --- a/python/ray/data/tests/test_parquet.py +++ b/python/ray/data/tests/test_parquet.py @@ -1179,6 +1179,18 @@ def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_pe assert len(table) == num_rows_per_file +@pytest.mark.parametrize("shuffle", [True, False, "file"]) +def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, shuffle): + + with pytest.raises(ValueError): + ray.data.read_parquet("example://iris.parquet", shuffle=shuffle) + + +@pytest.mark.parametrize("shuffle", [None, "files"]) +def test_valid_shuffle_arg_does_not_raise_error(ray_start_regular_shared, shuffle): + ray.data.read_parquet("example://iris.parquet", shuffle=shuffle) + + if __name__ == "__main__": import sys From 1e3ff25d471816e3c5c08c81d8bb1434399ad98c Mon Sep 17 00:00:00 2001 From: Matthew Deng Date: Mon, 12 Aug 2024 15:20:54 -0700 Subject: [PATCH 4/4] newline Signed-off-by: Matthew Deng --- python/ray/data/_internal/datasource/parquet_datasource.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index 8d7aca6c6a2e..da050627cec2 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -280,7 +280,6 @@ def __init__( self._schema = schema self._file_metadata_shuffler = None self._include_paths = include_paths - if shuffle == "files": self._file_metadata_shuffler = np.random.default_rng()