diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py
index 8e539aaa577d..fc92ca458067 100644
--- a/python/ray/data/datasource/file_based_datasource.py
+++ b/python/ray/data/datasource/file_based_datasource.py
@@ -150,6 +150,7 @@ def __init__(
                 "'file_extensions' field is set properly."
             )
 
+        _validate_shuffle_arg(shuffle)
         self._file_metadata_shuffler = None
         if shuffle == "files":
             self._file_metadata_shuffler = np.random.default_rng()
@@ -519,3 +520,11 @@ def _open_file_with_retry(
         max_attempts=OPEN_FILE_MAX_ATTEMPTS,
         max_backoff_s=OPEN_FILE_RETRY_MAX_BACKOFF_SECONDS,
     )
+
+
+def _validate_shuffle_arg(shuffle: Optional[str]) -> None:
+    if shuffle not in [None, "files"]:
+        raise ValueError(
+            f"Invalid value for 'shuffle': {shuffle}. "
+            "Valid values are None, 'files'."
+        )
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index 082270bab596..869bf8865cb0 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -755,6 +755,8 @@ def read_parquet(
         :class:`~ray.data.Dataset` producing records read from the specified parquet
         files.
     """
+    _validate_shuffle_arg(shuffle)
+
     if meta_provider is None:
         meta_provider = get_parquet_metadata_provider(override_num_blocks)
     arrow_parquet_args = _resolve_parquet_args(
@@ -3157,3 +3159,11 @@ def _get_num_output_blocks(
     elif override_num_blocks is not None:
         parallelism = override_num_blocks
     return parallelism
+
+
+def _validate_shuffle_arg(shuffle: Optional[str]) -> None:
+    if shuffle not in [None, "files"]:
+        raise ValueError(
+            f"Invalid value for 'shuffle': {shuffle}. "
+            "Valid values are None, 'files'."
+        )
diff --git a/python/ray/data/tests/test_file_based_datasource.py b/python/ray/data/tests/test_file_based_datasource.py
index 621150e148c6..ca7c4ff3325a 100644
--- a/python/ray/data/tests/test_file_based_datasource.py
+++ b/python/ray/data/tests/test_file_based_datasource.py
@@ -118,6 +118,17 @@ def test_windows_path():
     assert _is_local_windows_path("c:\\some\\where/mixed")
 
 
+@pytest.mark.parametrize("shuffle", [True, False, "file"])
+def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, shuffle):
+    with pytest.raises(ValueError):
+        FileBasedDatasource("example://iris.csv", shuffle=shuffle)
+
+
+@pytest.mark.parametrize("shuffle", [None, "files"])
+def test_valid_shuffle_arg_does_not_raise_error(ray_start_regular_shared, shuffle):
+    FileBasedDatasource("example://iris.csv", shuffle=shuffle)
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py
index bec6e1699a97..bbfd6ba69bfc 100644
--- a/python/ray/data/tests/test_parquet.py
+++ b/python/ray/data/tests/test_parquet.py
@@ -1179,6 +1179,18 @@ def test_write_num_rows_per_file(tmp_path, ray_start_regular_shared, num_rows_pe
     assert len(table) == num_rows_per_file
 
 
+@pytest.mark.parametrize("shuffle", [True, False, "file"])
+def test_invalid_shuffle_arg_raises_error(ray_start_regular_shared, shuffle):
+
+    with pytest.raises(ValueError):
+        ray.data.read_parquet("example://iris.parquet", shuffle=shuffle)
+
+
+@pytest.mark.parametrize("shuffle", [None, "files"])
+def test_valid_shuffle_arg_does_not_raise_error(ray_start_regular_shared, shuffle):
+    ray.data.read_parquet("example://iris.parquet", shuffle=shuffle)
+
+
 if __name__ == "__main__":
     import sys
 