diff --git a/python/ray/data/dataset_pipeline.py b/python/ray/data/dataset_pipeline.py
index 059c10f73cf0..272d3c060476 100644
--- a/python/ray/data/dataset_pipeline.py
+++ b/python/ray/data/dataset_pipeline.py
@@ -824,11 +824,14 @@ def random_shuffle_each_window(
         *,
         seed: Optional[int] = None,
         num_blocks: Optional[int] = None,
+        **ray_remote_args,
     ) -> "DatasetPipeline[U]":
         """Apply :py:meth:`Dataset.random_shuffle <ray.data.Dataset.random_shuffle>`
         to each dataset/window in this pipeline."""
         return self.foreach_window(
-            lambda ds: ds.random_shuffle(seed=seed, num_blocks=num_blocks)
+            lambda ds: ds.random_shuffle(
+                seed=seed, num_blocks=num_blocks, **ray_remote_args
+            )
         )
 
     def sort_each_window(
diff --git a/python/ray/data/tests/test_dataset_pipeline.py b/python/ray/data/tests/test_dataset_pipeline.py
index c33c1602a691..202d599c90be 100644
--- a/python/ray/data/tests/test_dataset_pipeline.py
+++ b/python/ray/data/tests/test_dataset_pipeline.py
@@ -634,6 +634,33 @@ def test_drop_columns(ray_start_regular_shared):
     assert pipe.drop_columns(["col2"]).take(1) == [{"col1": 1, "col3": 3}]
 
 
+def test_random_shuffle_each_window_with_custom_resource(ray_start_cluster):
+    ray.shutdown()
+    cluster = ray_start_cluster
+    # Create two nodes which have different custom resources.
+    cluster.add_node(
+        resources={"foo": 100},
+        num_cpus=1,
+    )
+    cluster.add_node(resources={"bar": 100}, num_cpus=1)
+
+    ray.init(cluster.address)
+
+    # Run pipeline in "bar" nodes.
+    pipe = ray.data.read_datasource(
+        ray.data.datasource.RangeDatasource(),
+        parallelism=10,
+        n=1000,
+        block_format="list",
+        ray_remote_args={"resources": {"bar": 1}},
+    ).repeat(3)
+    pipe = pipe.random_shuffle_each_window(resources={"bar": 1})
+    for batch in pipe.iter_batches():
+        pass
+    assert "1 nodes used" in pipe.stats()
+    assert "2 nodes used" not in pipe.stats()
+
+
 def test_in_place_transformation_doesnt_clear_objects(ray_start_regular_shared):
     ds = ray.data.from_items([1, 2, 3, 4, 5, 6])