[RLlib] Algorithm.add_policy() should alternatively accept an already instantiated policy object. #28637

Merged (10 commits) on Sep 26, 2022. Changes from 8 commits.
161 changes: 96 additions & 65 deletions rllib/algorithms/algorithm.py
@@ -231,7 +231,7 @@ class directly. Note that this arg can also be specified via
"""

# User provided (partial) config (this may be w/o the default
# Trainer's Config object). Will get merged with AlgorithmConfig()
# Algorithm's Config object). Will get merged with AlgorithmConfig()
Contributor Author: Found a few of these old `Trainer` references in the comments.

# in self.setup().
config = config or {}
# Resolve AlgorithmConfig into a plain dict.
@@ -343,7 +343,7 @@ def setup(self, config: PartialAlgorithmConfigDict):
# Validate the framework settings in config.
self.validate_framework(self.config)

# Set Trainer's seed after we have - if necessary - enabled
# Set Algorithm's seed after we have - if necessary - enabled
# tf eager-execution.
update_global_seed_if_necessary(self.config["framework"], self.config["seed"])

@@ -368,7 +368,7 @@ def setup(self, config: PartialAlgorithmConfigDict):

# Create a dict, mapping ActorHandles to sets of open remote
# requests (object refs). This way, we keep track, of which actors
# inside this Trainer (e.g. a remote RolloutWorker) have
# inside this Algorithm (e.g. a remote RolloutWorker) have
# already been sent how many (e.g. `sample()`) requests.
self.remote_requests_in_flight: DefaultDict[
ActorHandle, Set[ray.ObjectRef]
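Note: the bookkeeping described in the comment above can be illustrated with a small sketch (an assumption for illustration only, not code from this diff; `worker` stands for any Ray actor handle exposing a remote `sample()` method):

```python
from collections import defaultdict
import ray

# Map each remote actor to the set of object refs it has not finished yet.
remote_requests_in_flight = defaultdict(set)

def request_sample(worker, max_in_flight: int = 2):
    # Only send a new `sample()` request if this actor still has capacity.
    if len(remote_requests_in_flight[worker]) >= max_in_flight:
        return None
    ref = worker.sample.remote()
    remote_requests_in_flight[worker].add(ref)
    return ref

def mark_done(worker, ref):
    # Forget a request once its result has been fetched via `ray.get(ref)`.
    remote_requests_in_flight[worker].discard(ref)
```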
@@ -1547,7 +1547,8 @@ def set_weights(self, weights: Dict[PolicyID, dict]):
def add_policy(
self,
policy_id: PolicyID,
policy_cls: Type[Policy],
policy_cls: Optional[Type[Policy]] = None,
Contributor Author: Simplified the `Algorithm.add_policy()` method by only using the new WorkerSet APIs. No more micro-handling of individual workers' policy_maps here.

policy: Optional[Policy] = None,
*,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
@@ -1561,14 +1562,21 @@ def add_policy(
]
] = None,
evaluation_workers: bool = True,
workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None,
) -> Policy:
"""Adds a new policy to this Trainer.
worker_list: Optional[List[Union[RolloutWorker, ActorHandle]]] = None,
# Deprecated args:
workers=None,
) -> Optional[Policy]:
"""Adds a new policy to this Algorithm.

Args:
policy_id: ID of the policy to add.
policy_cls: The Policy class to use for
constructing the new Policy.
policy_cls: The Policy class to use for constructing the new Policy.
Note: Only one of `policy_cls` or `policy` must be provided.
policy: The Policy instance to add to this algorithm. If not None, the
given Policy object will be directly inserted into the Algorithm's
local worker and clones of that Policy will be created on all remote
workers as well as all evaluation workers.
Note: Only one of `policy_cls` or `policy` must be provided.
observation_space: The observation space of the policy to add.
If None, try to infer this space from the environment.
action_space: The action space of the policy to add.
@@ -1588,46 +1596,69 @@ def add_policy(
returns False) will not be updated.
evaluation_workers: Whether to add the new policy also
to the evaluation WorkerSet.
workers: A list of RolloutWorker/ActorHandles (remote
worker_list: A list of RolloutWorker/ActorHandles (remote
Member: I have a random request: do you mind staying with `workers` for now, since this is a simple name change? As part of the elastic training PR, I am getting rid of all these places where we access the underlying RolloutWorkers outside of WorkerSet. So if we deprecate this today, in a few days I am going to have to deprecate `worker_list` too, and we will have two deprecated fields here.

RolloutWorkers) to add this policy to. If defined, will only
add the given policy to these workers.

Returns:
The newly added policy (the copy that got added to the local
worker).
"""

kwargs = dict(
policy_id=policy_id,
policy_cls=policy_cls,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_state=policy_state,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=list(policies_to_train) if policies_to_train else None,
)

def fn(worker: RolloutWorker):
# `foreach_worker` function: Adds the policy the the worker (and
# maybe changes its policy_mapping_fn - if provided here).
worker.add_policy(**kwargs)
worker). If `worker_list` was provided, None is returned.

Raises:
ValueError: If both `policy_cls` AND `policy` are provided.
KeyError: If the given `policy_id` already exists in this Algorithm.
"""
# Deprecated args.
if workers is not None:
ray_gets = []
for worker in workers:
if isinstance(worker, ActorHandle):
ray_gets.append(worker.add_policy.remote(**kwargs))
else:
fn(worker)
ray.get(ray_gets)
deprecation_warning(
old="Algorithm.add_policy(.., workers=...)",
new="Algorithm.add_policy(.., worker_list=...)",
error=False,
)
worker_list = workers

# Worker list is explicitly provided -> Use only those workers (local or remote)
# specified.
if worker_list is not None:
RolloutWorker.add_policy_to_workers(
Member: I like this change; it actually works well with my fault tolerance PR. I will make `WorkerSet.add_policy_to_workers()` the only way to go about this in the future. One problem though: this should be `WorkerSet.`, not `RolloutWorker.`?

Contributor Author: Oh, yeah, great catch!

worker_list,
policy_id,
policy_cls,
policy,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_state=policy_state,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)
# Add to all our regular RolloutWorkers and maybe also all evaluation workers.
else:
# Run foreach_worker fn on all workers.
self.workers.foreach_worker(fn)
self.workers.add_policy(
policy_id,
policy_cls,
policy,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_state=policy_state,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)

# Update evaluation workers, if necessary.
if evaluation_workers and self.evaluation_workers is not None:
self.evaluation_workers.foreach_worker(fn)
# Add to evaluation workers, if necessary.
if evaluation_workers and self.evaluation_workers is not None:
self.evaluation_workers.add_policy(
policy_id,
policy_cls,
policy,
observation_space=observation_space,
action_space=action_space,
config=config,
policy_state=policy_state,
policy_mapping_fn=policy_mapping_fn,
policies_to_train=policies_to_train,
)

# Return newly added policy (from the local rollout worker).
return self.get_policy(policy_id)
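For context, a minimal caller-side sketch of the two calling conventions after this change (assumed usage, not taken from the PR's tests; `my_policy` stands for any pre-built Policy instance compatible with the env's spaces):

```python
from ray.rllib.algorithms.ppo import PPO

algo = PPO(config={"env": "CartPole-v0", "num_workers": 2})

# Existing way: pass a Policy class and let RLlib construct the new policy.
algo.add_policy(policy_id="new_policy", policy_cls=type(algo.get_policy()))

# New way added by this PR: pass an already instantiated Policy object.
algo.add_policy(
    policy_id="pretrained_policy",
    policy=my_policy,  # assumed: a pre-built, compatible Policy instance
    policy_mapping_fn=lambda agent_id, episode, worker, **kw: "pretrained_policy",
)
```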
@@ -1646,7 +1677,7 @@ def remove_policy(
] = None,
evaluation_workers: bool = True,
) -> None:
"""Removes a new policy from this Trainer.
"""Removes a new policy from this Algorithm.

Args:
policy_id: ID of the policy to be removed.
@@ -1693,12 +1724,12 @@ def export_policy_model(

Example:
>>> from ray.rllib.algorithms.ppo import PPO
>>> # Use a Trainer from RLlib or define your own.
>>> trainer = PPO(...) # doctest: +SKIP
>>> # Use an Algorithm from RLlib or define your own.
>>> algo = PPO(...) # doctest: +SKIP
>>> for _ in range(10): # doctest: +SKIP
>>> trainer.train() # doctest: +SKIP
>>> trainer.export_policy_model("/tmp/dir") # doctest: +SKIP
>>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
>>> algo.export_policy_model("/tmp/dir") # doctest: +SKIP
>>> algo.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP
"""
self.get_policy(policy_id).export_model(export_dir, onnx)

@@ -1718,11 +1749,11 @@ def export_policy_checkpoint(

Example:
>>> from ray.rllib.algorithms.ppo import PPO
>>> # Use a Trainer from RLlib or define your own.
>>> trainer = PPO(...) # doctest: +SKIP
>>> # Use an Algorithm from RLlib or define your own.
>>> algo = PPO(...) # doctest: +SKIP
>>> for _ in range(10): # doctest: +SKIP
>>> trainer.train() # doctest: +SKIP
>>> trainer.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
>>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
"""
self.get_policy(policy_id).export_checkpoint(export_dir, filename_prefix)

@@ -1740,10 +1771,10 @@ def import_policy_model_from_h5(

Example:
>>> from ray.rllib.algorithms.ppo import PPO
>>> trainer = PPO(...) # doctest: +SKIP
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
>>> algo = PPO(...) # doctest: +SKIP
>>> algo.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
>>> for _ in range(10): # doctest: +SKIP
>>> trainer.train() # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
"""
self.get_policy(policy_id).import_model_from_h5(import_file)
# Sync new weights to remote workers.
@@ -1787,7 +1818,7 @@ def default_resource_request(
cls, config: PartialAlgorithmConfigDict
) -> Union[Resources, PlacementGroupFactory]:

# Default logic for RLlib algorithms (Trainers):
# Default logic for RLlib Algorithms:
# Create one bundle per individual worker (local or remote).
# Use `num_cpus_for_driver` and `num_gpus` for the local worker and
# `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
@@ -1854,7 +1885,7 @@ def _get_env_id_and_creator(
Args:
env_specifier: An env class, an already tune registered env ID, a known
gym env name, or None (if no env is used).
config: The Trainer's (maybe partial) config dict.
config: The Algorithm's (maybe partial) config dict.

Returns:
Tuple consisting of a) env ID string and b) env creator callable.
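A short sketch of the first case, an already tune-registered env ID (illustrative only; the env name and config values are assumptions):

```python
import gym
from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPO

# Register a creator under an env ID; `_get_env_id_and_creator` can then
# resolve the string "my_cartpole" back to this callable.
register_env("my_cartpole", lambda env_config: gym.make("CartPole-v1"))

algo = PPO(config={"env": "my_cartpole", "framework": "torch"})
```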
@@ -1976,21 +2007,21 @@ def merge_trainer_configs(
config2: PartialAlgorithmConfigDict,
_allow_unknown_configs: Optional[bool] = None,
) -> AlgorithmConfigDict:
"""Merges a complete Trainer config with a partial override dict.
"""Merges a complete Algorithm config dict with a partial override dict.

Respects nested structures within the config dicts. The values in the
partial override dict take priority.

Args:
config1: The complete Trainer's dict to be merged (overridden)
config1: The complete Algorithm's dict to be merged (overridden)
with `config2`.
config2: The partial override config dict to merge on top of
`config1`.
_allow_unknown_configs: If True, keys in `config2` that don't exist
in `config1` are allowed and will be added to the final config.

Returns:
The merged full trainer config dict.
The merged full algorithm config dict.
"""
config1 = copy.deepcopy(config1)
if "callbacks" in config2 and type(config2["callbacks"]) is dict:
@@ -2092,7 +2123,7 @@ def resolve_tf_settings():
@OverrideToImplementCustomLogic_CallToSuperRecommended
@DeveloperAPI
def validate_config(self, config: AlgorithmConfigDict) -> None:
"""Validates a given config dict for this Trainer.
"""Validates a given config dict for this Algorithm.

Users should override this method to implement custom validation
behavior. It is recommended to call `super().validate_config()` in
@@ -2310,8 +2341,8 @@ def validate_config(self, config: AlgorithmConfigDict) -> None:
f"You have specified {config['evaluation_num_workers']} "
"evaluation workers, but your `evaluation_interval` is None! "
"Therefore, evaluation will not occur automatically with each"
" call to `Trainer.train()`. Instead, you will have to call "
"`Trainer.evaluate()` manually in order to trigger an "
" call to `Algorithm.train()`. Instead, you will have to call "
"`Algorithm.evaluate()` manually in order to trigger an "
"evaluation run."
)
# If `evaluation_num_workers=0` and
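Caller-side sketch of what this warning implies (assumed config values, for illustration only):

```python
from ray.rllib.algorithms.ppo import PPO

algo = PPO(config={
    "env": "CartPole-v0",
    "evaluation_num_workers": 2,
    "evaluation_interval": None,  # train() will never evaluate automatically
})
for _ in range(10):
    algo.train()                  # no evaluation happens here
eval_results = algo.evaluate()    # evaluation must be triggered manually
```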
@@ -2349,7 +2380,7 @@ def validate_config(self, config: AlgorithmConfigDict) -> None:
@staticmethod
@ExperimentalAPI
def validate_env(env: EnvType, env_context: EnvContext) -> None:
"""Env validator function for this Trainer class.
"""Env validator function for this Algorithm class.

Override this in child classes to define custom validation
behavior.
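A sketch of such a child-class override (an assumption, not from this diff; `MyPPO` and the Box-only check are illustrative):

```python
import gym
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.env.env_context import EnvContext

class MyPPO(PPO):
    @staticmethod
    def validate_env(env, env_context: EnvContext) -> None:
        # Reject environments whose observation space is not a Box.
        if not isinstance(env.observation_space, gym.spaces.Box):
            raise ValueError("MyPPO only supports Box observation spaces.")
```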
@@ -2547,7 +2578,7 @@ def _create_local_replay_buffer_if_necessary(
config: Algorithm-specific configuration data.

Returns:
MultiAgentReplayBuffer instance based on trainer config.
MultiAgentReplayBuffer instance based on algorithm config.
None, if local replay buffer is not needed.
"""
if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
@@ -2842,7 +2873,7 @@ def _record_usage(self, config):
alg = "USER_DEFINED"
record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)

@Deprecated(new="Trainer.compute_single_action()", error=False)
@Deprecated(new="Algorithm.compute_single_action()", error=False)
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)

@@ -2872,7 +2903,7 @@ def _make_workers(
)

@staticmethod
@Deprecated(new="Trainer.validate_config()", error=False)
@Deprecated(new="Algorithm.validate_config()", error=False)
def _validate_config(config, trainer_or_none):
assert trainer_or_none is not None
return trainer_or_none.validate_config(config)