Skip to content

Commit

Permalink
[tune/execution] Update staged resources in a fixed counter for faste…
Browse files Browse the repository at this point in the history
…r lookup (#32087)

In #30016 we migrated Ray Tune to use a new resource management interface. In the same PR, we simplified the resource consolidation logic. This lead to a performance regression first identified in #31337.

After manual profiling, the regression seems to come from `RayTrialExecutor._count_staged_resources`. We have 1000 staged trials, and this function is called on every step, executing a linear scan through all trials.

This PR fixes this performance bottleneck by keeping state of the resource counter instead of dynamically recreating it every time. This is simple as we can just add/subtract the resources whenever we add/remove from the `RayTrialExecutor._staged_trials` set.

Manual testing confirmed this improves the runtime of `tune_scalability_result_throughput_cluster` from ~132 seconds to ~122 seconds, bringing it back to the same level as before the refactor.

Signed-off-by: Kai Fricke <[email protected]>
  • Loading branch information
krfricke authored Jan 31, 2023
1 parent 7573d49 commit 10d52f7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
12 changes: 6 additions & 6 deletions python/ray/tune/execution/ray_trial_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ def __init__(
] = defaultdict(list)

# Trials for which we requested resources
self._staged_trials = set()
self._staged_trials = set() # Staged trials
self._staged_resources = Counter() # Resources of staged trials
self._trial_to_acquired_resources: Dict[Trial, AcquiredResources] = {}

# Result buffer
Expand Down Expand Up @@ -319,6 +320,7 @@ def _stage_and_update_status(self, trials: Iterable[Trial]):
resource_request = trial.placement_group_factory

self._staged_trials.add(trial)
self._staged_resources[trial.placement_group_factory] += 1
self._resource_manager.request_resources(resource_request=resource_request)

self._resource_manager.update_state()
Expand Down Expand Up @@ -533,6 +535,7 @@ def _unstage_trial_with_resources(self, trial: Trial):
# Case 1: The trial we started was staged. Just remove it
if trial in self._staged_trials:
self._staged_trials.remove(trial)
self._staged_resources[trial.placement_group_factory] -= 1
return

# Case 2: We staged a trial "A" with the same resources, but our trial "B"
Expand All @@ -551,6 +554,7 @@ def _unstage_trial_with_resources(self, trial: Trial):

if candidate_trial:
self._staged_trials.remove(candidate_trial)
self._staged_resources[candidate_trial.placement_group_factory] -= 1
return

raise RuntimeError(
Expand Down Expand Up @@ -848,11 +852,7 @@ def on_step_end(self, search_ended: bool = False) -> None:
self._do_force_trial_cleanup()

def _count_staged_resources(self):
counter = Counter()
for trial in self._staged_trials:
resource_request = trial.placement_group_factory
counter[resource_request] += 1
return counter
return self._staged_resources

def _cleanup_cached_actors(
self, search_ended: bool = False, force_all: bool = False
Expand Down
2 changes: 1 addition & 1 deletion release/ray_release/alerts/tune_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def handle_result(
target_time = 900 if not was_smoke_test else 400
elif test_name == "result_throughput_cluster":
target_terminated = 1000
target_time = 160
target_time = 130
elif test_name == "result_throughput_single_node":
target_terminated = 96
target_time = 120
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
Test owner: krfricke
Acceptance criteria: Should run faster than 160 seconds.
Acceptance criteria: Should run faster than 130 seconds.
Theoretical minimum time: 100 seconds
"""
import os

import ray
from ray import tune
from ray.tune.execution.cluster_info import _is_ray_cluster

from ray.tune.utils.release_test_util import timed_tune_run

Expand All @@ -31,11 +30,7 @@ def main():
results_per_second = 0.5
trial_length_s = 100

max_runtime = 160

if _is_ray_cluster():
# Add constant overhead for SSH connection
max_runtime = 160
max_runtime = 130

timed_tune_run(
name="result throughput cluster",
Expand Down

0 comments on commit 10d52f7

Please sign in to comment.