Skip to content

Commit

Permalink
[RLlib] Add timeout to filter synchronization. (#25959)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst authored Jun 24, 2022
1 parent eee866d commit bed9083
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 13 deletions.
20 changes: 17 additions & 3 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,12 @@ def step(self) -> ResultDict:

if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
# Sync filters on workers.
self._sync_filters_if_needed(self.workers)
self._sync_filters_if_needed(
self.workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)

# Collect worker metrics and combine them with `results`.
if self.config["_disable_execution_plan_api"]:
Expand Down Expand Up @@ -674,7 +679,12 @@ def evaluate(
self.evaluation_workers.sync_weights(
from_worker=self.workers.local_worker()
)
self._sync_filters_if_needed(self.evaluation_workers)
self._sync_filters_if_needed(
self.evaluation_workers,
timeout_seconds=self.config[
"sync_filters_on_rollout_workers_timeout_s"
],
)

if self.config["custom_eval_function"]:
logger.info(
Expand Down Expand Up @@ -1597,12 +1607,15 @@ def _is_multi_agent(self):
'(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
)

def _sync_filters_if_needed(self, workers: WorkerSet):
def _sync_filters_if_needed(
self, workers: WorkerSet, timeout_seconds: Optional[float] = None
):
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
FilterManager.synchronize(
workers.local_worker().filters,
workers.remote_workers(),
update_remote=self.config["synchronize_filters"],
timeout_seconds=timeout_seconds,
)
logger.debug(
"synchronized filters: {}".format(workers.local_worker().filters)
Expand Down Expand Up @@ -2273,6 +2286,7 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
ignore=self.config["ignore_worker_failures"],
recreate=self.config["recreate_failed_workers"],
)

return results, train_iter_ctx

def _run_one_evaluation(
Expand Down
1 change: 1 addition & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __init__(self, algo_class=None):
# TODO: Set this flag still in the config or - much better - in the
# RolloutWorker as a property.
self.in_evaluation = False
self.sync_filters_on_rollout_workers_timeout_s = 60.0

# `self.reporting()`
self.keep_per_episode_custom_metrics = False
Expand Down
6 changes: 1 addition & 5 deletions rllib/evaluation/tests/test_trajectory_view_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ def test_traj_view_normal_case(self):
policy = algo.get_policy()
view_req_model = policy.model.view_requirements
view_req_policy = policy.view_requirements
print(_)
print(view_req_policy)
print(view_req_model)
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 11, view_req_policy
for key in [
Expand Down Expand Up @@ -321,8 +318,7 @@ def policy_fn(agent_id, episode, **kwargs):
normalize_actions=False,
num_envs=1,
)
batch = rollout_worker_w_api.sample()
print(batch)
batch = rollout_worker_w_api.sample() # noqa: F841

def test_counting_by_agent_steps(self):
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
Expand Down
44 changes: 39 additions & 5 deletions rllib/utils/filter_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import logging
from typing import Optional

import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.exceptions import GetTimeoutError

logger = logging.getLogger(__name__)


@DeveloperAPI
Expand All @@ -10,7 +16,12 @@ class FilterManager:

@staticmethod
@DeveloperAPI
def synchronize(
    local_filters,
    remotes,
    update_remote=True,
    timeout_seconds: Optional[float] = None,
):
    """Aggregates all filters from remote evaluators.

    Local copy is updated and then broadcasted to all remote evaluators.

    Both remote round trips (fetching the remote filters and pushing the
    merged filters back) are bounded by `timeout_seconds`. A timeout is
    logged as an error and does not abort the synchronization.

    Args:
        local_filters: Filters to be synchronized.
        remotes: Remote evaluators with filters.
        update_remote: Whether to push updates to remote filters.
        timeout_seconds: How long to wait, in seconds, for the remote
            evaluators to return their filters and, separately, to accept
            the updated ones. Forwarded as `timeout` to `ray.get`.
    """
    try:
        remote_filters = ray.get(
            [r.get_filters.remote(flush_after=True) for r in remotes],
            timeout=timeout_seconds,
        )
    except GetTimeoutError:
        logger.error(
            "Failed to get remote filters from a rollout worker in "
            "FilterManager. "
            "Filtered "
            "metrics may be computed, but filtered wrong."
        )
        # Nothing was fetched: fall back to an empty list so the aggregation
        # loop below is skipped instead of failing on an unbound name.
        remote_filters = []

    # Fold every remote filter's changes into the matching local filter.
    for rf in remote_filters:
        for k in local_filters:
            local_filters[k].apply_changes(rf[k], with_buffer=False)

    if update_remote:
        copies = {k: v.as_serializable() for k, v in local_filters.items()}
        # Put the serialized filters into the object store once and hand the
        # same reference to every remote.
        remote_copy = ray.put(copies)

        try:
            # Wait on the sync calls so a hanging worker surfaces as a
            # timeout here.
            ray.get(
                [r.sync_filters.remote(remote_copy) for r in remotes],
                timeout=timeout_seconds,
            )
        except GetTimeoutError:
            logger.error(
                "Failed to set remote filters to a rollout worker in "
                "FilterManager. "
                "Filtered "
                "metrics may be computed, but filtered wrong."
            )

0 comments on commit bed9083

Please sign in to comment.