From 15f150d02b1585fa6c1d84d32348059fc74981b1 Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Tue, 28 Mar 2023 00:31:10 +0200
Subject: [PATCH] [RLlib] Fix: Recovered eval worker should use eval-config's
 policy_mapping_fn and policy_to_train fn, not the main train workers' ones.
 (#33648)

Signed-off-by: sven1977
Signed-off-by: Kourosh Hakhamaneshi
---
 rllib/algorithms/algorithm.py                 |  56 ++++------
 rllib/algorithms/tests/test_algorithm.py      |  10 +-
 .../algorithms/tests/test_worker_failures.py  | 100 +++++++++++++-----
 rllib/evaluation/worker_set.py                |  10 +-
 rllib/examples/custom_eval.py                 |  20 ++--
 rllib/utils/actor_manager.py                  |   8 +-
 6 files changed, 120 insertions(+), 84 deletions(-)

diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index fcbffd0f2426..f8729e639bde 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -618,8 +618,6 @@ def setup(self, config: AlgorithmConfig) -> None:
                 default_policy_class=self.get_default_policy_class(self.config),
                 config=self.evaluation_config,
                 num_workers=self.config.evaluation_num_workers,
-                # Don't even create a local worker if num_workers > 0.
-                local_worker=False,
                 logdir=self.logdir,
             )
 
@@ -1276,14 +1274,14 @@ def restore_workers(self, workers: WorkerSet):
             workers: The WorkerSet to restore. This may be Rollout or Evaluation
                 workers.
         """
+        # If `workers` is None, or
+        # 1. `workers` (WorkerSet) does not have a local worker, and
+        # 2. `self.workers` (WorkerSet used for training) does not have a local worker
+        # -> we don't have a local worker to get state from, so we can't recover
+        # remote worker in this case.
         if not workers or (
             not workers.local_worker() and not self.workers.local_worker()
         ):
-            # If workers does not exist, or
-            # 1. this WorkerSet does not have a local worker, and
-            # 2. self.workers (rollout worker set) does not have a local worker,
-            # we don't have a local worker to get state from.
-            # We can't recover remote worker in this case.
             return
 
         # This is really cheap, since probe_unhealthy_workers() is a no-op
@@ -1292,12 +1290,16 @@
         if restored:
             from_worker = workers.local_worker() or self.workers.local_worker()
-            state = ray.put(from_worker.get_state())
+            # Get the state of the correct (reference) worker. E.g. The local worker
+            # of the main WorkerSet.
+            state_ref = ray.put(from_worker.get_state())
+
             # By default, entire local worker state is synced after restoration
             # to bring these workers up to date.
             workers.foreach_worker(
-                func=lambda w: w.set_state(ray.get(state)),
+                func=lambda w: w.set_state(ray.get(state_ref)),
                 remote_worker_ids=restored,
+                # Don't update the local_worker, b/c it's the one we are synching from.
                 local_worker=False,
                 timeout_seconds=self.config.worker_restore_timeout_s,
                 # Bring back actor after successful state syncing.
@@ -1836,12 +1838,12 @@ def add_policy(
 
         if workers is not DEPRECATED_VALUE:
             deprecation_warning(
-                old="workers",
+                old="Algorithm.add_policy(.., workers=..)",
                 help=(
-                    "The `workers` argument to `Algorithm.add_policy()` is deprecated "
-                    "and no-op now. Please do not use it anymore."
+                    "The `workers` argument to `Algorithm.add_policy()` is deprecated! "
+                    "Please do not use it anymore."
                 ),
-                error=False,
+                error=True,
             )
 
         self.workers.add_policy(
@@ -2552,7 +2554,6 @@ def __setstate__(self, state) -> None:
             # there in case they are used for evaluation purpose.
             self.evaluation_workers.foreach_worker(
                 lambda w: w.set_state(ray.get(remote_state)),
-                local_worker=False,
                 healthy_only=False,
             )
         # If necessary, restore replay data as well.
@@ -3023,29 +3024,12 @@ def _record_usage(self, config):
     def compute_action(self, *args, **kwargs):
         return self.compute_single_action(*args, **kwargs)
 
-    @Deprecated(new="construct WorkerSet(...) instance directly", error=False)
-    def _make_workers(
-        self,
-        *,
-        env_creator: EnvCreator,
-        validate_env: Optional[Callable[[EnvType, EnvContext], None]],
-        policy_class: Type[Policy],
-        config: AlgorithmConfigDict,
-        num_workers: int,
-        local_worker: bool = True,
-    ) -> WorkerSet:
-        return WorkerSet(
-            env_creator=env_creator,
-            validate_env=validate_env,
-            default_policy_class=policy_class,
-            config=config,
-            num_workers=num_workers,
-            local_worker=local_worker,
-            logdir=self.logdir,
-        )
+    @Deprecated(new="construct WorkerSet(...) instance directly", error=True)
+    def _make_workers(self, *args, **kwargs):
+        pass
 
-    def validate_config(self, config) -> None:
-        # TODO: Deprecate. All logic has been moved into the AlgorithmConfig classes.
+    @Deprecated(new="AlgorithmConfig.validate()", error=False)
+    def validate_config(self, config):
         pass
 
     @staticmethod
diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py
index 28f2a42f51f5..910e135de8e0 100644
--- a/rllib/algorithms/tests/test_algorithm.py
+++ b/rllib/algorithms/tests/test_algorithm.py
@@ -109,12 +109,14 @@ def new_mapping_fn(agent_id, episode, worker, **kwargs):
             # Make sure new policy is part of remote workers in the
             # worker set and the eval worker set.
             self.assertTrue(
-                algo.workers.foreach_worker(func=lambda w: pid in w.policy_map)[0]
+                all(algo.workers.foreach_worker(func=lambda w: pid in w.policy_map))
             )
             self.assertTrue(
-                algo.evaluation_workers.foreach_worker(
-                    func=lambda w: pid in w.policy_map
-                )[0]
+                all(
+                    algo.evaluation_workers.foreach_worker(
+                        func=lambda w: pid in w.policy_map
+                    )
+                )
             )
             # Assert new policy is part of local worker (eval worker set does NOT
diff --git a/rllib/algorithms/tests/test_worker_failures.py b/rllib/algorithms/tests/test_worker_failures.py
index 268326118515..5b258bc58699 100644
--- a/rllib/algorithms/tests/test_worker_failures.py
+++ b/rllib/algorithms/tests/test_worker_failures.py
@@ -216,7 +216,7 @@ def setUpClass(cls) -> None:
         register_env("fault_env", lambda c: FaultInjectEnv(c))
         register_env(
-            "multi-agent-fault_env", lambda c: make_multi_agent(FaultInjectEnv)(c)
+            "multi_agent_fault_env", lambda c: make_multi_agent(FaultInjectEnv)(c)
         )
 
     @classmethod
@@ -284,28 +284,47 @@ def _do_test_fault_fatal(self, config, fail_eval=False):
             self.assertRaises(Exception, lambda: a.train())
             a.stop()
 
-    def _do_test_fault_fatal_but_recreate(self, config):
+    def _do_test_fault_fatal_but_recreate(self, config, multi_agent=False):
         # Counter that will survive restarts.
-        COUNTER_NAME = "_do_test_fault_fatal_but_recreate"
+        COUNTER_NAME = (
+            f"_do_test_fault_fatal_but_recreate{'_ma' if multi_agent else ''}"
+        )
         counter = Counter.options(name=COUNTER_NAME).remote()
 
         # Test raises real error when out of workers.
        config.num_rollout_workers = 1
        config.evaluation_num_workers = 1
        config.evaluation_interval = 1
-        config.env = "fault_env"
-        config.evaluation_config = {
-            "recreate_failed_workers": True,
+        config.env = "fault_env" if not multi_agent else "multi_agent_fault_env"
+        config.evaluation_config = AlgorithmConfig.overrides(
+            recreate_failed_workers=True,
             # 0 delay for testing purposes.
- "delay_between_worker_restarts_s": 0, + delay_between_worker_restarts_s=0, # Make eval worker (index 1) fail. - "env_config": { + env_config={ "bad_indices": [1], "failure_start_count": 3, "failure_stop_count": 4, "counter": COUNTER_NAME, }, - } + **( + dict( + policy_mapping_fn=( + lambda aid, episode, worker, **kwargs: ( + # Allows this test to query this + # different-from-training-workers policy mapping fn. + "This is the eval mapping fn" + if episode is None + else "main" + if episode.episode_id % 2 == aid + else "p{}".format(np.random.choice([0, 1])) + ) + ) + ) + if multi_agent + else {} + ), + ) for _ in framework_iterator(config, frameworks=("tf2", "torch")): # Reset interaction counter. @@ -313,21 +332,22 @@ def _do_test_fault_fatal_but_recreate(self, config): a = config.build() - a.train() - wait_for_restore() - a.train() - - self.assertEqual(a.workers.num_healthy_remote_workers(), 1) - self.assertEqual(a.evaluation_workers.num_healthy_remote_workers(), 1) - # This should also work several times. - a.train() - wait_for_restore() - a.train() - - self.assertEqual(a.workers.num_healthy_remote_workers(), 1) - self.assertEqual(a.evaluation_workers.num_healthy_remote_workers(), 1) - + for _ in range(2): + a.train() + wait_for_restore() + a.train() + + self.assertEqual(a.workers.num_healthy_remote_workers(), 1) + self.assertEqual(a.evaluation_workers.num_healthy_remote_workers(), 1) + if multi_agent: + # Make a dummy call to the eval worker's policy_mapping_fn and + # make sure the restored eval worker received the correct one from + # the eval config (not the main workers' one). + test = a.evaluation_workers.foreach_worker( + lambda w: w.policy_mapping_fn(0, None, None) + ) + self.assertEqual(test[0], "This is the eval mapping fn") a.stop() def test_fatal(self): @@ -429,6 +449,36 @@ def test_recreate_eval_workers_parallel_to_training_w_actor_manager(self): self._do_test_fault_fatal_but_recreate(config) + def test_recreate_eval_workers_parallel_to_training_w_actor_manager_and_multi_agent( + self, + ): + # Test the case where all eval workers fail on a multi-agent env with + # different `policy_mapping_fn` in eval- vs train workers, but we chose + # to recover. + config = ( + PGConfig() + .multi_agent( + policies={"main", "p0", "p1"}, + policy_mapping_fn=( + lambda aid, episode, worker, **kwargs: ( + "main" + if episode.episode_id % 2 == aid + else "p{}".format(np.random.choice([0, 1])) + ) + ), + ) + .evaluation( + evaluation_num_workers=1, + enable_async_evaluation=True, + evaluation_parallel_to_training=True, + evaluation_duration="auto", + ) + .training(model={"fcnet_hiddens": [4]}) + .debugging(worker_cls=ForwardHealthCheckToEnvWorker) + ) + + self._do_test_fault_fatal_but_recreate(config, multi_agent=True) + def test_eval_workers_failing_fatal(self): # Test the case where all eval workers fail (w/o recovery). self._do_test_fault_fatal( @@ -526,7 +576,7 @@ def on_algorithm_init(self, *, algorithm, **kwargs): model={"fcnet_hiddens": [4]}, ) .environment( - env="multi-agent-fault_env", + env="multi_agent_fault_env", env_config={ # Make both worker idx=1 and 2 fail. "bad_indices": [1, 2], @@ -838,7 +888,7 @@ def test_multi_agent_env_eval_workers_fault_but_restore_env(self): model={"fcnet_hiddens": [4]}, ) .environment( - env="multi-agent-fault_env", + env="multi_agent_fault_env", # Workers do not fault and no fault tolerance. 
                 env_config={},
                 disable_env_checking=True,
diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index 780f811f401d..ecb8889fb579 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -522,12 +522,12 @@ def add_policy(
 
         if workers is not DEPRECATED_VALUE:
             deprecation_warning(
-                old="workers",
+                old="WorkerSet.add_policy(.., workers=..)",
                 help=(
-                    "The `workers` argument to `WorkerSet.add_policy()` is deprecated "
-                    "and a no-op now. Please do not use it anymore."
+                    "The `workers` argument to `WorkerSet.add_policy()` is deprecated! "
+                    "Please do not use it anymore."
                 ),
-                error=False,
+                error=True,
             )
 
         if (policy_cls is None) == (policy is None):
@@ -576,6 +576,7 @@ def _create_new_policy_fn(worker: RolloutWorker):
                 worker.add_policy(**new_policy_instance_kwargs)
 
         if self.local_worker() is not None:
+            # Add policy directly by (already instantiated) object.
             if policy is not None:
                 self.local_worker().add_policy(
                     policy_id=policy_id,
                     policy=policy,
                     policy_mapping_fn=policy_mapping_fn,
                     policies_to_train=policies_to_train,
                     module_spec=module_spec,
                 )
+            # Add policy by constructor kwargs.
             else:
                 self.local_worker().add_policy(**new_policy_instance_kwargs)
diff --git a/rllib/examples/custom_eval.py b/rllib/examples/custom_eval.py
index 826782f93664..1cef98781658 100644
--- a/rllib/examples/custom_eval.py
+++ b/rllib/examples/custom_eval.py
@@ -118,23 +118,19 @@ def custom_eval_function(algorithm, eval_workers):
     Returns:
         metrics: Evaluation metrics dict.
     """
-
-    # We configured 2 eval workers in the training config.
-    funcs = [
-        lambda w: w.foreach_env(lambda env: env.set_corridor_length(4)),
-        lambda w: w.foreach_env(lambda env: env.set_corridor_length(7)),
-    ]
-
-    # Set different env settings for each worker. Here we use a fixed config,
-    # which also could have been computed in each worker by looking at
-    # env_config.worker_index (printed in SimpleCorridor class above).
-    eval_workers.foreach_worker(func=funcs)
+    # Set different env settings for each worker. Here we use the worker's
+    # `worker_index` property.
+    eval_workers.foreach_worker(
+        func=lambda w: w.foreach_env(
+            lambda env: env.set_corridor_length(4 if w.worker_index == 1 else 7)
+        )
+    )
 
     for i in range(5):
         print("Custom evaluation round", i)
         # Calling .sample() runs exactly one episode per worker due to how the
         # eval workers are configured.
-        eval_workers.foreach_worker(func=lambda w: w.sample())
+        eval_workers.foreach_worker(func=lambda w: w.sample(), local_worker=False)
 
     # Collect the accumulated episodes on the workers, and then summarize the
     # episode stats into a metrics dict.
diff --git a/rllib/utils/actor_manager.py b/rllib/utils/actor_manager.py
index 92a1d920b088..1d06bb9cca97 100644
--- a/rllib/utils/actor_manager.py
+++ b/rllib/utils/actor_manager.py
@@ -341,7 +341,7 @@ def num_actors(self) -> int:
     @DeveloperAPI
     def num_healthy_actors(self) -> int:
         """Return the number of healthy remote actors."""
-        return sum([s.is_healthy for s in self.__remote_actor_states.values()])
+        return sum(s.is_healthy for s in self.__remote_actor_states.values())
 
     @DeveloperAPI
     def total_num_restarts(self) -> int:
@@ -785,10 +785,12 @@ def probe_unhealthy_actors(
     ) -> List[int]:
         """Ping all unhealthy actors to try bringing them back.
 
-        Returns:
-            A list of actor ids that are restored.
+        Args:
             timeout_seconds: Timeout to avoid pinging hanging workers indefinitely.
             mark_healthy: Whether to mark actors healthy if they respond to the ping.
+
+        Returns:
+            A list of actor ids that are restored.
         """
         unhealthy_actor_ids = [
             actor_id