Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rllib] Add timeout to filter synchronization #25959

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,12 @@ def step(self) -> ResultDict:

if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
# Sync filters on workers.
self._sync_filters_if_needed(self.workers)
self._sync_filters_if_needed(
self.workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)

# Collect worker metrics.
if self.config["_disable_execution_plan_api"]:
Expand Down Expand Up @@ -672,7 +677,12 @@ def evaluate(
self.evaluation_workers.sync_weights(
from_worker=self.workers.local_worker()
)
self._sync_filters_if_needed(self.evaluation_workers)
self._sync_filters_if_needed(
self.evaluation_workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)

if self.config["custom_eval_function"]:
logger.info(
Expand Down Expand Up @@ -1594,12 +1604,13 @@ def _is_multi_agent(self):
'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
)

def _sync_filters_if_needed(self, workers: WorkerSet):
def _sync_filters_if_needed(self, workers: WorkerSet, timeout_seconds: int = None):
ArturNiederfahrenhorst marked this conversation as resolved.
Show resolved Hide resolved
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
FilterManager.synchronize(
workers.local_worker().filters,
workers.remote_workers(),
update_remote=self.config["synchronize_filters"],
timeout_seconds=timeout_seconds,
)
logger.debug(
"synchronized filters: {}".format(workers.local_worker().filters)
Expand Down Expand Up @@ -2270,6 +2281,7 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
ignore=self.config["ignore_worker_failures"],
recreate=self.config["recreate_failed_workers"],
)

return results, train_iter_ctx

def _run_one_evaluation(
Expand Down
1 change: 1 addition & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __init__(self, algo_class=None):
# TODO: Set this flag still in the config or - much better - in the
# RolloutWorker as a property.
self.in_evaluation = False
self.sync_filters_on_rollout_workers_timeout_s = 60
ArturNiederfahrenhorst marked this conversation as resolved.
Show resolved Hide resolved

# `self.reporting()`
self.keep_per_episode_custom_metrics = False
Expand Down
6 changes: 1 addition & 5 deletions rllib/evaluation/tests/test_trajectory_view_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ def test_traj_view_normal_case(self):
policy = algo.get_policy()
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements
print(_)
print(view_req_policy)
print(view_req_model)
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 11, view_req_policy
for key in [
Expand Down Expand Up @@ -321,8 +318,7 @@ def policy_fn(agent_id, episode, **kwargs):
normalize_actions=False,
num_envs=1,
)
batch = rollout_worker_w_api.sample()
print(batch)
batch = rollout_worker_w_api.sample() # noqa: F841

def test_counting_by_agent_steps(self):
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
Expand Down
40 changes: 35 additions & 5 deletions rllib/utils/filter_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import logging

import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.exceptions import GetTimeoutError

logger = logging.getLogger(__name__)


@DeveloperAPI
Expand All @@ -10,7 +15,9 @@ class FilterManager:

@staticmethod
@DeveloperAPI
def synchronize(local_filters, remotes, update_remote=True):
def synchronize(
local_filters, remotes, update_remote=True, timeout_seconds: int = None
ArturNiederfahrenhorst marked this conversation as resolved.
Show resolved Hide resolved
):
"""Aggregates all filters from remote evaluators.

Local copy is updated and then broadcasted to all remote evaluators.
Expand All @@ -19,14 +26,37 @@ def synchronize(local_filters, remotes, update_remote=True):
local_filters: Filters to be synchronized.
remotes: Remote evaluators with filters.
update_remote: Whether to push updates to remote filters.
timeout_seconds: How long to wait for filter to get or set filters
"""
remote_filters = ray.get(
[r.get_filters.remote(flush_after=True) for r in remotes]
)
try:
remote_filters = ray.get(
[r.get_filters.remote(flush_after=True) for r in remotes],
timeout=timeout_seconds,
)
except GetTimeoutError:
logger.error(
"Failed to get remote filters from a rollout worker in "
"FilterManager. "
"Filtered "
"metrics may be computed, but filtered wrong."
)

for rf in remote_filters:
for k in local_filters:
local_filters[k].apply_changes(rf[k], with_buffer=False)
if update_remote:
copies = {k: v.as_serializable() for k, v in local_filters.items()}
remote_copy = ray.put(copies)
[r.sync_filters.remote(remote_copy) for r in remotes]

try:
ray.get(
[r.sync_filters.remote(remote_copy) for r in remotes],
timeout=timeout_seconds,
)
except GetTimeoutError:
logger.error(
"Failed to set remote filters to a rollout worker in "
"FilterManager. "
"Filtered "
"metrics may be computed, but filtered wrong."
)