From 06f5d4a143dfb3103ad310385147878f1f5fe88d Mon Sep 17 00:00:00 2001
From: Cheng Su
Date: Mon, 29 Jan 2024 15:13:57 -0800
Subject: [PATCH] Address comments

Signed-off-by: Cheng Su
---
 .../concurrency_cap_backpressure_policy.py          | 9 +++------
 python/ray/data/tests/test_backpressure_policies.py | 4 ++--
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py
index e1b307c3e6d1..a52bd1f6ab9f 100644
--- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py
+++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py
@@ -18,14 +18,11 @@
 class ConcurrencyCapBackpressurePolicy(BackpressurePolicy):
     """A backpressure policy that caps the concurrency of each operator.
 
-    The concurrency cap limits the number of concurrently running tasks.
-
-    The concrete stategy is as follows:
-    - Each PhysicalOperator is assigned a concurrency cap.
-    - An PhysicalOperator can run new tasks if the number of running tasks is less
-    than the cap.
+    The policy will limit the number of concurrently running tasks based on its
+    concurrency cap parameter.
 
     NOTE: Only support setting concurrency cap for `TaskPoolMapOperator` for now.
+    TODO(chengsu): Consolidate with actor scaling logic of `ActorPoolMapOperator`.
     """
 
     def __init__(self, topology: "Topology"):
diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py
index 9df247454713..e852e182f414 100644
--- a/python/ray/data/tests/test_backpressure_policies.py
+++ b/python/ray/data/tests/test_backpressure_policies.py
@@ -125,8 +125,8 @@ def test_e2e_normal(self):
         N = self.__class__._cluster_cpus
         ds = ray.data.range(N, parallelism=N)
         # Use different `num_cpus` to make sure they don't fuse.
-        ds = ds.map_batches(map_func1, batch_size=None, num_cpus=1)
-        ds = ds.map_batches(map_func2, batch_size=None, num_cpus=1.1)
+        ds = ds.map_batches(map_func1, batch_size=None, num_cpus=1, concurrency=1)
+        ds = ds.map_batches(map_func2, batch_size=None, num_cpus=1.1, concurrency=1)
         res = ds.take_all()
         self.assertEqual(len(res), N)