diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index fc2a363166ab..c4cf138865be 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -612,7 +612,12 @@ def step(self) -> ResultDict:
 
         if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
             # Sync filters on workers.
-            self._sync_filters_if_needed(self.workers)
+            self._sync_filters_if_needed(
+                self.workers,
+                timeout_seconds=self.config[
+                    "sync_filters_on_rollout_workers_timeout_s"
+                ],
+            )
 
             # Collect worker metrics and add combine them with `results`.
             if self.config["_disable_execution_plan_api"]:
@@ -674,7 +679,12 @@ def evaluate(
                 self.evaluation_workers.sync_weights(
                     from_worker=self.workers.local_worker()
                 )
-                self._sync_filters_if_needed(self.evaluation_workers)
+                self._sync_filters_if_needed(
+                    self.evaluation_workers,
+                    timeout_seconds=self.config[
+                        "sync_filters_on_rollout_workers_timeout_s"
+                    ],
+                )
 
             if self.config["custom_eval_function"]:
                 logger.info(
@@ -1597,12 +1607,15 @@ def _is_multi_agent(self):
                 '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
             )
 
-    def _sync_filters_if_needed(self, workers: WorkerSet):
+    def _sync_filters_if_needed(
+        self, workers: WorkerSet, timeout_seconds: Optional[float] = None
+    ):
         if self.config.get("observation_filter", "NoFilter") != "NoFilter":
             FilterManager.synchronize(
                 workers.local_worker().filters,
                 workers.remote_workers(),
                 update_remote=self.config["synchronize_filters"],
+                timeout_seconds=timeout_seconds,
             )
             logger.debug(
                 "synchronized filters: {}".format(workers.local_worker().filters)
             )
@@ -2273,6 +2286,7 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
                         ignore=self.config["ignore_worker_failures"],
                         recreate=self.config["recreate_failed_workers"],
                     )
+
         return results, train_iter_ctx
 
     def _run_one_evaluation(
diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index f231b14427ca..f41580e4ac08 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -189,6 +189,7 @@ def __init__(self, algo_class=None):
         # TODO: Set this flag still in the config or - much better - in the
         #  RolloutWorker as a property.
         self.in_evaluation = False
+        self.sync_filters_on_rollout_workers_timeout_s = 60.0
 
         # `self.reporting()`
         self.keep_per_episode_custom_metrics = False
diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py
index 0980656191b7..d9b1b0997c71 100644
--- a/rllib/evaluation/tests/test_trajectory_view_api.py
+++ b/rllib/evaluation/tests/test_trajectory_view_api.py
@@ -64,9 +64,6 @@ def test_traj_view_normal_case(self):
             policy = algo.get_policy()
             view_req_model = policy.model.view_requirements
             view_req_policy = policy.view_requirements
-            print(_)
-            print(view_req_policy)
-            print(view_req_model)
             assert len(view_req_model) == 1, view_req_model
             assert len(view_req_policy) == 11, view_req_policy
             for key in [
@@ -321,8 +318,7 @@ def policy_fn(agent_id, episode, **kwargs):
             normalize_actions=False,
             num_envs=1,
         )
-        batch = rollout_worker_w_api.sample()
-        print(batch)
+        batch = rollout_worker_w_api.sample()  # noqa: F841
 
     def test_counting_by_agent_steps(self):
         config = copy.deepcopy(ppo.DEFAULT_CONFIG)
diff --git a/rllib/utils/filter_manager.py b/rllib/utils/filter_manager.py
index 9374b7855f52..939b611fa238 100644
--- a/rllib/utils/filter_manager.py
+++ b/rllib/utils/filter_manager.py
@@ -1,5 +1,11 @@
+import logging
+from typing import Optional
+
 import ray
 from ray.rllib.utils.annotations import DeveloperAPI
+from ray.exceptions import GetTimeoutError
+
+logger = logging.getLogger(__name__)
 
 
 @DeveloperAPI
@@ -10,7 +16,12 @@ class FilterManager:
 
     @staticmethod
     @DeveloperAPI
-    def synchronize(local_filters, remotes, update_remote=True):
+    def synchronize(
+        local_filters,
+        remotes,
+        update_remote=True,
+        timeout_seconds: Optional[float] = None,
+    ):
         """Aggregates all filters from remote evaluators.
 
         Local copy is updated and then broadcasted to all remote evaluators.
@@ -19,14 +30,37 @@ def synchronize(local_filters, remotes, update_remote=True):
         Args:
             local_filters: Filters to be synchronized.
             remotes: Remote evaluators with filters.
             update_remote: Whether to push updates to remote filters.
+            timeout_seconds: How long to wait for getting or setting remote filters.
         """
-        remote_filters = ray.get(
-            [r.get_filters.remote(flush_after=True) for r in remotes]
-        )
+        try:
+            remote_filters = ray.get(
+                [r.get_filters.remote(flush_after=True) for r in remotes],
+                timeout=timeout_seconds,
+            )
+        except GetTimeoutError:
+            logger.error(
+                "Failed to get remote filters from a rollout worker in "
+                "FilterManager. "
+                "Filtered metrics may be computed, "
+                "but filtering may be incorrect."
+            )
+            # Fall back to an empty list so the loop below does not crash.
+            remote_filters = []
+
         for rf in remote_filters:
             for k in local_filters:
                 local_filters[k].apply_changes(rf[k], with_buffer=False)
         if update_remote:
             copies = {k: v.as_serializable() for k, v in local_filters.items()}
             remote_copy = ray.put(copies)
-            [r.sync_filters.remote(remote_copy) for r in remotes]
+
+            try:
+                ray.get(
+                    [r.sync_filters.remote(remote_copy) for r in remotes],
+                    timeout=timeout_seconds,
+                )
+            except GetTimeoutError:
+                logger.error(
+                    "Failed to set remote filters to a rollout worker in "
+                    "FilterManager. "
+                    "Filtered metrics may be computed, "
+                    "but filtering may be incorrect."
+                )
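
Below is a minimal usage sketch, not part of the patch, of how the new timeout could be configured. It assumes the dict-style RLlib config and a PPO/CartPole setup purely for illustration; only the `sync_filters_on_rollout_workers_timeout_s` key (60.0 s default), the `observation_filter` requirement, and the `timeout_seconds` argument of `FilterManager.synchronize` come from the diff above.

import copy

from ray.rllib.algorithms import ppo

# Illustrative config only; the timeout key is the piece added by this patch.
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
config["env"] = "CartPole-v1"
config["num_workers"] = 2
# Filters are only synchronized when an observation filter is active.
config["observation_filter"] = "MeanStdFilter"
# New knob from this patch: stop waiting on unresponsive rollout workers
# after 10 seconds instead of the 60-second default.
config["sync_filters_on_rollout_workers_timeout_s"] = 10.0

algo = ppo.PPO(config=config)
# step() forwards the configured timeout to FilterManager.synchronize().
results = algo.train()

With `timeout_seconds=None` (the `FilterManager.synchronize` default) the `ray.get` calls block as before; `Algorithm.step()` and `evaluate()` always pass the config value, so a hanging rollout worker should no longer stall filter synchronization indefinitely.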