From 0c0a10b291471ad3a784ffcafed9f4c24a74773f Mon Sep 17 00:00:00 2001
From: sven1977
Date: Thu, 1 Aug 2024 23:05:14 +0200
Subject: [PATCH 1/2] wip

Signed-off-by: sven1977
---
 rllib/algorithms/algorithm.py               | 61 ++++++++++++++-------
 rllib/algorithms/appo/appo_learner.py       |  7 ++-
 rllib/algorithms/dqn/dqn_rainbow_learner.py |  7 ++-
 3 files changed, 52 insertions(+), 23 deletions(-)

diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 0d5d47c20ca1..7c846bef1f18 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -2021,7 +2021,7 @@ def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
         Args:
             policy_id: ID of the policy to return.
         """
-        return self.env_runner_group.local_env_runner.get_policy(policy_id)
+        return self.env_runner.get_policy(policy_id)
 
     @PublicAPI
     def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
@@ -2034,7 +2034,7 @@ def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
         # New API stack (get weights from LearnerGroup).
         if self.learner_group is not None:
             return self.learner_group.get_weights(module_ids=policies)
-        return self.env_runner_group.local_env_runner.get_weights(policies)
+        return self.env_runner.get_weights(policies)
 
     @PublicAPI
     def set_weights(self, weights: Dict[PolicyID, dict]):
@@ -2399,8 +2399,12 @@ def add_policy(
                 Callable[[PolicyID, Optional[SampleBatchType]], bool],
             ]
         ] = None,
-        evaluation_workers: bool = True,
+        add_to_learners: bool = True,
+        add_to_env_runners: bool = True,
+        add_to_eval_env_runners: bool = True,
         module_spec: Optional[RLModuleSpec] = None,
+        # Deprecated arg.
+        evaluation_workers=DEPRECATED_VALUE,
     ) -> Optional[Policy]:
         """Adds a new policy to this Algorithm.
 
@@ -2433,8 +2437,12 @@ def add_policy(
                 If None, will keep the existing setup in place. Policies,
                 whose IDs are not in the list (or for which the callable
                 returns False) will not be updated.
-            evaluation_workers: Whether to add the new policy also
-                to the evaluation EnvRunnerGroup.
+            add_to_learners: Whether to add the new RLModule to the LearnerGroup
+                (with its n Learners).
+            add_to_env_runners: Whether to add the new RLModule to the EnvRunnerGroup
+                (with its m EnvRunners plus the local one).
+            add_to_eval_env_runners: Whether to add the new RLModule to the eval
+                EnvRunnerGroup (with its o EnvRunners plus the local one).
             module_spec: In the new RLModule API we need to pass in the module_spec
                 for the new module that is supposed to be added. Knowing the policy
                 spec is not sufficient.
@@ -2451,24 +2459,32 @@ def add_policy(
                 "example."
             )
 
+        if evaluation_workers != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="Algorithm.add_policy(evaluation_workers=...)",
+                new="Algorithm.add_policy(add_to_eval_env_runners=...)",
+                error=True,
+            )
+
         validate_module_id(policy_id, error=True)
 
-        self.env_runner_group.add_policy(
-            policy_id,
-            policy_cls,
-            policy,
-            observation_space=observation_space,
-            action_space=action_space,
-            config=config,
-            policy_state=policy_state,
-            policy_mapping_fn=policy_mapping_fn,
-            policies_to_train=policies_to_train,
-            module_spec=module_spec,
-        )
+        if add_to_env_runners is True:
+            self.env_runner_group.add_policy(
+                policy_id,
+                policy_cls,
+                policy,
+                observation_space=observation_space,
+                action_space=action_space,
+                config=config,
+                policy_state=policy_state,
+                policy_mapping_fn=policy_mapping_fn,
+                policies_to_train=policies_to_train,
+                module_spec=module_spec,
+            )
 
         # If learner API is enabled, we need to also add the underlying module
         # to the learner group.
-        if self.config.enable_rl_module_and_learner:
+        if add_to_learners and self.config.enable_rl_module_and_learner:
             policy = self.get_policy(policy_id)
             module = policy.model
             self.learner_group.add_module(
@@ -2490,7 +2506,7 @@ def add_policy(
             self.learner_group.set_weights({policy_id: weights})
 
         # Add to evaluation workers, if necessary.
-        if evaluation_workers is True and self.eval_env_runner_group is not None:
+        if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
             self.eval_env_runner_group.add_policy(
                 policy_id,
                 policy_cls,
@@ -2504,8 +2520,11 @@ def add_policy(
                 module_spec=module_spec,
             )
 
-        # Return newly added policy (from the local rollout worker).
-        return self.get_policy(policy_id)
+        # Return newly added policy (from the local EnvRunner).
+        if add_to_env_runners:
+            return self.get_policy(policy_id)
+        elif add_to_eval_env_runners and self.eval_env_runner_group:
+            return self.eval_env_runner.policy_map[policy_id]
 
     @OldAPIStack
     def remove_policy(
diff --git a/rllib/algorithms/appo/appo_learner.py b/rllib/algorithms/appo/appo_learner.py
index 9440dd9c33ca..a1c06a854309 100644
--- a/rllib/algorithms/appo/appo_learner.py
+++ b/rllib/algorithms/appo/appo_learner.py
@@ -58,7 +58,12 @@ def add_module(
         config_overrides: Optional[Dict] = None,
         new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
     ) -> MultiRLModuleSpec:
-        marl_spec = super().add_module(module_id=module_id)
+        marl_spec = super().add_module(
+            module_id=module_id,
+            module_spec=module_spec,
+            config_overrides=config_overrides,
+            new_should_module_be_updated=new_should_module_be_updated,
+        )
         # Create target networks for added Module, if applicable.
         if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
             self.module[module_id].unwrapped().make_target_networks()
diff --git a/rllib/algorithms/dqn/dqn_rainbow_learner.py b/rllib/algorithms/dqn/dqn_rainbow_learner.py
index 41f2e48f44f2..b09174ab2c90 100644
--- a/rllib/algorithms/dqn/dqn_rainbow_learner.py
+++ b/rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -72,7 +72,12 @@ def add_module(
         config_overrides: Optional[Dict] = None,
         new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
     ) -> MultiRLModuleSpec:
-        marl_spec = super().add_module(module_id=module_id)
+        marl_spec = super().add_module(
+            module_id=module_id,
+            module_spec=module_spec,
+            config_overrides=config_overrides,
+            new_should_module_be_updated=new_should_module_be_updated,
+        )
         # Create target networks for added Module, if applicable.
         if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
             self.module[module_id].unwrapped().make_target_networks()

From dcd26c2395d22eecc6a47943db3942a7d0b22767 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Fri, 2 Aug 2024 08:57:21 +0200
Subject: [PATCH 2/2] wip

Signed-off-by: sven1977
---
 rllib/algorithms/algorithm.py                 | 38 +++++++++++++++----
 .../self_play_with_policy_checkpoint.py       |  2 +-
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py
index 7c846bef1f18..ff10fcc19f66 100644
--- a/rllib/algorithms/algorithm.py
+++ b/rllib/algorithms/algorithm.py
@@ -2438,7 +2438,8 @@ def add_policy(
                 whose IDs are not in the list (or for which the callable
                 returns False) will not be updated.
             add_to_learners: Whether to add the new RLModule to the LearnerGroup
-                (with its n Learners).
+                (with its n Learners). This setting is only valid on the hybrid-API
+                stack (with Learners, but w/o EnvRunners).
             add_to_env_runners: Whether to add the new RLModule to the EnvRunnerGroup
                 (with its m EnvRunners plus the local one).
             add_to_eval_env_runners: Whether to add the new RLModule to the eval
@@ -2482,7 +2483,7 @@ def add_policy(
             module_spec=module_spec,
         )
 
-        # If learner API is enabled, we need to also add the underlying module
+        # If Learner API is enabled, we need to also add the underlying module
         # to the learner group.
         if add_to_learners and self.config.enable_rl_module_and_learner:
             policy = self.get_policy(policy_id)
@@ -2538,7 +2539,11 @@ def remove_policy(
                 Callable[[PolicyID, Optional[SampleBatchType]], bool],
             ]
         ] = None,
-        evaluation_workers: bool = True,
+        remove_from_learners: bool = True,
+        remove_from_env_runners: bool = True,
+        remove_from_eval_env_runners: bool = True,
+        # Deprecated args.
+        evaluation_workers=DEPRECATED_VALUE,
     ) -> None:
         """Removes a policy from this Algorithm.
 
@@ -2554,9 +2559,21 @@ def remove_policy(
                 If None, will keep the existing setup in place. Policies,
                 whose IDs are not in the list (or for which the callable
                 returns False) will not be updated.
-            evaluation_workers: Whether to also remove the policy from the
-                evaluation EnvRunnerGroup.
+            remove_from_learners: Whether to remove the Policy from the LearnerGroup
+                (with its n Learners). Only valid on the hybrid API stack (w/ Learners,
+                but w/o EnvRunners).
+            remove_from_env_runners: Whether to remove the Policy from the
+                EnvRunnerGroup (with its m EnvRunners plus the local one).
+            remove_from_eval_env_runners: Whether to remove the RLModule from the eval
+                EnvRunnerGroup (with its o EnvRunners plus the local one).
         """
+        if evaluation_workers != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="Algorithm.remove_policy(evaluation_workers=...)",
+                new="Algorithm.remove_policy(remove_from_eval_env_runners=...)",
+                error=False,
+            )
+            remove_from_eval_env_runners = evaluation_workers
 
         def fn(worker):
             worker.remove_policy(
@@ -2563,22 +2580,27 @@ def remove_policy(
                 policy_id=policy_id,
                 policy_mapping_fn=policy_mapping_fn,
                 policies_to_train=policies_to_train,
             )
 
         # Update all EnvRunner workers.
-        self.env_runner_group.foreach_worker(fn, local_env_runner=True)
+        if remove_from_env_runners:
+            self.env_runner_group.foreach_worker(fn, local_env_runner=True)
 
         # Update each Learner's `policies_to_train` information, but only
         # if the arg is explicitly provided here.
-        if self.config.enable_rl_module_and_learner and policies_to_train is not None:
+        if (
+            remove_from_learners
+            and self.config.enable_rl_module_and_learner
+            and policies_to_train is not None
+        ):
             self.learner_group.foreach_learner(
                 func=lambda learner: learner.config.multi_agent(
                     policies_to_train=policies_to_train
                 ),
             )
 
         # Update the evaluation worker set's workers, if required.
-        if evaluation_workers and self.eval_env_runner_group is not None:
+        if remove_from_eval_env_runners and self.eval_env_runner_group is not None:
             self.eval_env_runner_group.foreach_worker(fn, local_env_runner=True)
 
     @OldAPIStack
diff --git a/rllib/examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py b/rllib/examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py
index f15994c0456c..6215f476a964 100644
--- a/rllib/examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py
+++ b/rllib/examples/_old_api_stack/connectors/self_play_with_policy_checkpoint.py
@@ -64,7 +64,7 @@ def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
         algorithm.add_policy(
             policy_id=OPPONENT_POLICY_ID,
             policy=policy,
-            evaluation_workers=True,
+            add_to_eval_env_runners=True,
        )
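
A minimal migration sketch for callers of the two changed methods, assuming an Algorithm on which the old-API-stack `add_policy()`/`remove_policy()` path is still supported (old or hybrid stack). The `PPOConfig` setup and the "opponent" policy ID are illustrative only and not part of the patches above:

    from ray.rllib.algorithms.ppo import PPOConfig

    # Illustrative setup; any Algorithm with a policy_map works the same way.
    algo = PPOConfig().environment("CartPole-v1").build()

    # Before these patches:
    #   algo.add_policy("opponent", policy_cls=..., evaluation_workers=True)
    # After: select the target groups explicitly. Passing the old
    # `evaluation_workers` arg to add_policy() now errors out.
    algo.add_policy(
        policy_id="opponent",
        policy_cls=type(algo.get_policy()),
        add_to_env_runners=True,        # sampling EnvRunnerGroup
        add_to_eval_env_runners=True,   # eval EnvRunnerGroup, if configured
    )

    # Symmetric removal. Passing `evaluation_workers=...` here only logs a
    # deprecation warning and maps onto `remove_from_eval_env_runners=...`.
    algo.remove_policy(
        "opponent",
        remove_from_env_runners=True,
        remove_from_eval_env_runners=True,
    )

Per the first patch, `add_policy()` still returns the newly added Policy: from the local EnvRunner when `add_to_env_runners=True`, otherwise from the local eval EnvRunner, and `None` if the module was added to the Learners only.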