[RLlib] Add more control to Algorithm.add_policy method over which components to add the Policy to. #46932

Merged
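The new add_to_learners, add_to_env_runners, and add_to_eval_env_runners flags on Algorithm.add_policy (and the matching remove_from_* flags on Algorithm.remove_policy) let callers pick exactly which components a new Policy/RLModule is added to or removed from. A minimal usage sketch, assuming an already built multi-agent Algorithm ("algo") on the old/hybrid API stack with an existing "main" policy; the object and policy IDs below are placeholders, not taken from this PR:

# Sketch only: "algo" is an existing multi-agent Algorithm (old/hybrid API
# stack), "main" an already registered policy ID.
main_policy = algo.get_policy("main")

# Add a new policy to the training EnvRunnerGroup only; skip the LearnerGroup
# and the evaluation EnvRunnerGroup.
algo.add_policy(
    policy_id="opponent",
    policy_cls=type(main_policy),
    observation_space=main_policy.observation_space,
    action_space=main_policy.action_space,
    policy_state=main_policy.get_state(),
    add_to_learners=False,
    add_to_env_runners=True,
    add_to_eval_env_runners=False,
)

# Later, remove it again from the training EnvRunners only.
algo.remove_policy(
    policy_id="opponent",
    remove_from_learners=False,
    remove_from_env_runners=True,
    remove_from_eval_env_runners=False,
)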
97 changes: 69 additions & 28 deletions rllib/algorithms/algorithm.py
@@ -2021,7 +2021,7 @@ def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
         Args:
             policy_id: ID of the policy to return.
         """
-        return self.env_runner_group.local_env_runner.get_policy(policy_id)
+        return self.env_runner.get_policy(policy_id)
 
     @PublicAPI
     def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
@@ -2034,7 +2034,7 @@ def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
         # New API stack (get weights from LearnerGroup).
         if self.learner_group is not None:
             return self.learner_group.get_weights(module_ids=policies)
-        return self.env_runner_group.local_env_runner.get_weights(policies)
+        return self.env_runner.get_weights(policies)
 
     @PublicAPI
     def set_weights(self, weights: Dict[PolicyID, dict]):
@@ -2399,8 +2399,12 @@ def add_policy(
                 Callable[[PolicyID, Optional[SampleBatchType]], bool],
             ]
         ] = None,
-        evaluation_workers: bool = True,
+        add_to_learners: bool = True,
+        add_to_env_runners: bool = True,
+        add_to_eval_env_runners: bool = True,
         module_spec: Optional[RLModuleSpec] = None,
+        # Deprecated arg.
+        evaluation_workers=DEPRECATED_VALUE,
     ) -> Optional[Policy]:
         """Adds a new policy to this Algorithm.
 
@@ -2433,8 +2437,13 @@ def add_policy(
                 If None, will keep the existing setup in place. Policies,
                 whose IDs are not in the list (or for which the callable
                 returns False) will not be updated.
-            evaluation_workers: Whether to add the new policy also
-                to the evaluation EnvRunnerGroup.
+            add_to_learners: Whether to add the new RLModule to the LearnerGroup
+                (with its n Learners). This setting is only valid on the hybrid-API
+                stack (with Learners, but w/o EnvRunners).
+            add_to_env_runners: Whether to add the new RLModule to the EnvRunnerGroup
+                (with its m EnvRunners plus the local one).
+            add_to_eval_env_runners: Whether to add the new RLModule to the eval
+                EnvRunnerGroup (with its o EnvRunners plus the local one).
             module_spec: In the new RLModule API we need to pass in the module_spec for
                 the new module that is supposed to be added. Knowing the policy spec is
                 not sufficient.
@@ -2451,24 +2460,32 @@ def add_policy(
                 "example."
             )
 
+        if evaluation_workers != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="Algorithm.add_policy(evaluation_workers=...)",
+                new="Algorithm.add_policy(add_to_eval_env_runners=...)",
+                error=True,
+            )
+
         validate_module_id(policy_id, error=True)
 
-        self.env_runner_group.add_policy(
-            policy_id,
-            policy_cls,
-            policy,
-            observation_space=observation_space,
-            action_space=action_space,
-            config=config,
-            policy_state=policy_state,
-            policy_mapping_fn=policy_mapping_fn,
-            policies_to_train=policies_to_train,
-            module_spec=module_spec,
-        )
+        if add_to_env_runners is True:
+            self.env_runner_group.add_policy(
+                policy_id,
+                policy_cls,
+                policy,
+                observation_space=observation_space,
+                action_space=action_space,
+                config=config,
+                policy_state=policy_state,
+                policy_mapping_fn=policy_mapping_fn,
+                policies_to_train=policies_to_train,
+                module_spec=module_spec,
+            )
 
-        # If learner API is enabled, we need to also add the underlying module
+        # If Learner API is enabled, we need to also add the underlying module
         # to the learner group.
-        if self.config.enable_rl_module_and_learner:
+        if add_to_learners and self.config.enable_rl_module_and_learner:
             policy = self.get_policy(policy_id)
             module = policy.model
             self.learner_group.add_module(
@@ -2490,7 +2507,7 @@ def add_policy(
                 self.learner_group.set_weights({policy_id: weights})
 
         # Add to evaluation workers, if necessary.
-        if evaluation_workers is True and self.eval_env_runner_group is not None:
+        if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
             self.eval_env_runner_group.add_policy(
                 policy_id,
                 policy_cls,
@@ -2504,8 +2521,11 @@ def add_policy(
                 module_spec=module_spec,
             )
 
-        # Return newly added policy (from the local rollout worker).
-        return self.get_policy(policy_id)
+        # Return newly added policy (from the local EnvRunner).
+        if add_to_env_runners:
+            return self.get_policy(policy_id)
+        elif add_to_eval_env_runners and self.eval_env_runner_group:
+            return self.eval_env_runner.policy_map[policy_id]
 
     @OldAPIStack
     def remove_policy(
@@ -2519,7 +2539,11 @@ def remove_policy(
                 Callable[[PolicyID, Optional[SampleBatchType]], bool],
             ]
         ] = None,
-        evaluation_workers: bool = True,
+        remove_from_learners: bool = True,
+        remove_from_env_runners: bool = True,
+        remove_from_eval_env_runners: bool = True,
+        # Deprecated args.
+        evaluation_workers=DEPRECATED_VALUE,
     ) -> None:
         """Removes a policy from this Algorithm.
 
@@ -2535,9 +2559,21 @@ def remove_policy(
                 If None, will keep the existing setup in place. Policies,
                 whose IDs are not in the list (or for which the callable
                 returns False) will not be updated.
-            evaluation_workers: Whether to also remove the policy from the
-                evaluation EnvRunnerGroup.
+            remove_from_learners: Whether to remove the Policy from the LearnerGroup
+                (with its n Learners). Only valid on the hybrid API stack (w/ Learners,
+                but w/o EnvRunners).
+            remove_from_env_runners: Whether to remove the Policy from the
+                EnvRunnerGroup (with its m EnvRunners plus the local one).
+            remove_from_eval_env_runners: Whether to remove the RLModule from the eval
+                EnvRunnerGroup (with its o EnvRunners plus the local one).
         """
+        if evaluation_workers != DEPRECATED_VALUE:
+            deprecation_warning(
+                old="Algorithm.remove_policy(evaluation_workers=...)",
+                new="Algorithm.remove_policy(remove_from_eval_env_runners=...)",
+                error=False,
+            )
+            remove_from_eval_env_runners = evaluation_workers
 
         def fn(worker):
             worker.remove_policy(
@@ -2547,11 +2583,16 @@ def fn(worker):
             )
 
         # Update all EnvRunner workers.
-        self.env_runner_group.foreach_worker(fn, local_env_runner=True)
+        if remove_from_env_runners:
+            self.env_runner_group.foreach_worker(fn, local_env_runner=True)
 
         # Update each Learner's `policies_to_train` information, but only
         # if the arg is explicitly provided here.
-        if self.config.enable_rl_module_and_learner and policies_to_train is not None:
+        if (
+            remove_from_learners
+            and self.config.enable_rl_module_and_learner
+            and policies_to_train is not None
+        ):
             self.learner_group.foreach_learner(
                 func=lambda learner: learner.config.multi_agent(
                     policies_to_train=policies_to_train
@@ -2560,7 +2601,7 @@ def fn(worker):
             )
 
         # Update the evaluation worker set's workers, if required.
-        if evaluation_workers and self.eval_env_runner_group is not None:
+        if remove_from_eval_env_runners and self.eval_env_runner_group is not None:
             self.eval_env_runner_group.foreach_worker(fn, local_env_runner=True)
 
     @OldAPIStack
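For existing call sites, the deprecated evaluation_workers argument now raises an error in add_policy and only warns (while being mapped onto remove_from_eval_env_runners) in remove_policy, per the deprecation_warning calls above. A migration sketch; "algo" and "opponent_policy" are placeholders, not from this PR:

# Old call sites (pre-#46932); add_policy(evaluation_workers=...) now errors,
# remove_policy(evaluation_workers=...) still works but emits a deprecation
# warning:
#   algo.add_policy(policy_id="opponent", policy=opponent_policy, evaluation_workers=True)
#   algo.remove_policy(policy_id="opponent", evaluation_workers=True)

# Updated call sites:
algo.add_policy(
    policy_id="opponent",
    policy=opponent_policy,
    add_to_eval_env_runners=True,
)
algo.remove_policy(
    policy_id="opponent",
    remove_from_eval_env_runners=True,
)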
7 changes: 6 additions & 1 deletion rllib/algorithms/appo/appo_learner.py
@@ -58,7 +58,12 @@ def add_module(
         config_overrides: Optional[Dict] = None,
         new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
     ) -> MultiRLModuleSpec:
-        marl_spec = super().add_module(module_id=module_id)
+        marl_spec = super().add_module(
+            module_id=module_id,
+            module_spec=module_spec,
+            config_overrides=config_overrides,
+            new_should_module_be_updated=new_should_module_be_updated,
+        )
         # Create target networks for added Module, if applicable.
         if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
             self.module[module_id].unwrapped().make_target_networks()
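This change (and the identical one in dqn_rainbow_learner.py below) fixes the add_module overrides dropping module_spec, config_overrides, and new_should_module_be_updated when delegating to super().add_module(); all arguments are now forwarded. A sketch of the same pattern for a custom Learner subclass; the class name and the post-setup comment are illustrative, not part of this PR:

from ray.rllib.core.learner.learner import Learner


class MyCustomLearner(Learner):  # illustrative subclass, not from this PR
    def add_module(
        self,
        *,
        module_id,
        module_spec,
        config_overrides=None,
        new_should_module_be_updated=None,
    ):
        # Forward *all* arguments so a caller-provided module spec or config
        # override is not silently dropped when the new module is built.
        marl_spec = super().add_module(
            module_id=module_id,
            module_spec=module_spec,
            config_overrides=config_overrides,
            new_should_module_be_updated=new_should_module_be_updated,
        )
        # Algorithm-specific setup for the new module (e.g. target networks in
        # APPO/DQN) would go here.
        return marl_spec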
7 changes: 6 additions & 1 deletion rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -72,7 +72,12 @@ def add_module(
         config_overrides: Optional[Dict] = None,
         new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
     ) -> MultiRLModuleSpec:
-        marl_spec = super().add_module(module_id=module_id)
+        marl_spec = super().add_module(
+            module_id=module_id,
+            module_spec=module_spec,
+            config_overrides=config_overrides,
+            new_should_module_be_updated=new_should_module_be_updated,
+        )
         # Create target networks for added Module, if applicable.
         if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI):
             self.module[module_id].unwrapped().make_target_networks()
@@ -64,7 +64,7 @@ def on_algorithm_init(self, *, algorithm, metrics_logger, **kwargs):
         algorithm.add_policy(
             policy_id=OPPONENT_POLICY_ID,
             policy=policy,
-            evaluation_workers=True,
+            add_to_eval_env_runners=True,
         )