[RLlib] Algorithm.add_policy() should alternatively accept an already instantiated policy object. #28637
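To make the intent of this change concrete, here is a rough usage sketch of the two calling styles `add_policy()` supports after this PR: the existing class-based style and the new instance-based one. The `PPOConfig`/`CartPole-v1` setup and the `RandomPolicy` example class are illustrative assumptions, not code from this diff:

```python
# Rough usage sketch (assumes a PPO algo built on CartPole-v1 and RLlib's
# RandomPolicy example class; everything outside this diff is illustrative).
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.policy.random_policy import RandomPolicy

algo = PPOConfig().environment("CartPole-v1").build()

# Previously the only option: pass a Policy *class* and let RLlib construct it.
algo.add_policy(policy_id="random_by_class", policy_cls=RandomPolicy)

# What this PR adds: pass an already *instantiated* Policy object instead.
# The instance is inserted into the local worker; remote and evaluation
# workers receive clones of it.
donor_policy = algo.get_policy("default_policy")
algo.add_policy(policy_id="cloned_from_instance", policy=donor_policy)
```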
@@ -231,7 +231,7 @@ class directly. Note that this arg can also be specified via
        """
        # User provided (partial) config (this may be w/o the default
-       # Trainer's Config object). Will get merged with AlgorithmConfig()
+       # Algorithm's Config object). Will get merged with AlgorithmConfig()
        # in self.setup().
        config = config or {}
        # Resolve AlgorithmConfig into a plain dict.
@@ -343,7 +343,7 @@ def setup(self, config: PartialAlgorithmConfigDict):
        # Validate the framework settings in config.
        self.validate_framework(self.config)

-       # Set Trainer's seed after we have - if necessary - enabled
+       # Set Algorithm's seed after we have - if necessary - enabled
        # tf eager-execution.
        update_global_seed_if_necessary(self.config["framework"], self.config["seed"])
@@ -368,7 +368,7 @@ def setup(self, config: PartialAlgorithmConfigDict):

        # Create a dict, mapping ActorHandles to sets of open remote
        # requests (object refs). This way, we keep track, of which actors
-       # inside this Trainer (e.g. a remote RolloutWorker) have
+       # inside this Algorithm (e.g. a remote RolloutWorker) have
        # already been sent how many (e.g. `sample()`) requests.
        self.remote_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]
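For context, the `remote_requests_in_flight` mapping touched above follows a common bookkeeping pattern for asynchronous Ray calls: a `defaultdict` of sets keyed by actor handle, so pending object refs can be recorded when a request is launched and discarded once it resolves. A minimal, hypothetical sketch of that pattern (the `DummyWorker` actor and the loop are illustrative, not part of this PR):

```python
from collections import defaultdict
from typing import DefaultDict, Set

import ray
from ray.actor import ActorHandle

ray.init()

@ray.remote
class DummyWorker:
    def sample(self) -> int:
        return 42

workers = [DummyWorker.remote() for _ in range(2)]

# Track open requests per actor, mirroring the pattern used in setup() above.
in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

for w in workers:
    ref = w.sample.remote()
    in_flight[w].add(ref)

# Resolve whatever is ready and drop it from the bookkeeping dict.
ready, _ = ray.wait([r for refs in in_flight.values() for r in refs], num_returns=2)
for refs in in_flight.values():
    refs.difference_update(ready)
print(ray.get(ready))
```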
@@ -1547,7 +1547,8 @@ def set_weights(self, weights: Dict[PolicyID, dict]):
    def add_policy(
        self,
        policy_id: PolicyID,
-       policy_cls: Type[Policy],
+       policy_cls: Optional[Type[Policy]] = None,
+       policy: Optional[Policy] = None,
        *,
        observation_space: Optional[gym.spaces.Space] = None,
        action_space: Optional[gym.spaces.Space] = None,
@@ -1561,14 +1562,21 @@ def add_policy(
            ]
        ] = None,
        evaluation_workers: bool = True,
-       workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None,
-   ) -> Policy:
-       """Adds a new policy to this Trainer.
+       worker_list: Optional[List[Union[RolloutWorker, ActorHandle]]] = None,
+       # Deprecated args:
+       workers=None,
+   ) -> Optional[Policy]:
+       """Adds a new policy to this Algorithm.

        Args:
            policy_id: ID of the policy to add.
-           policy_cls: The Policy class to use for
-               constructing the new Policy.
+           policy_cls: The Policy class to use for constructing the new Policy.
+               Note: Only one of `policy_cls` or `policy` must be provided.
+           policy: The Policy instance to add to this algorithm. If not None, the
+               given Policy object will be directly inserted into the Algorithm's
+               local worker and clones of that Policy will be created on all remote
+               workers as well as all evaluation workers.
+               Note: Only one of `policy_cls` or `policy` must be provided.
            observation_space: The observation space of the policy to add.
                If None, try to infer this space from the environment.
            action_space: The action space of the policy to add.
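The "only one of `policy_cls` or `policy`" rule in the new docstring is the usual exactly-one-of-two-arguments validation. A generic, standalone sketch of what such a check typically looks like (not the exact code from this PR):

```python
from typing import Optional, Type

def _check_policy_args(policy_cls: Optional[Type] = None, policy: Optional[object] = None) -> None:
    # Exactly one of the two ways to specify the new policy must be used.
    if (policy_cls is None) == (policy is None):
        raise ValueError(
            "Only one of `policy_cls` or `policy` must be provided to "
            "`Algorithm.add_policy()`!"
        )
```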
@@ -1588,46 +1596,69 @@ def add_policy(
                returns False) will not be updated.
            evaluation_workers: Whether to add the new policy also
                to the evaluation WorkerSet.
-           workers: A list of RolloutWorker/ActorHandles (remote
+           worker_list: A list of RolloutWorker/ActorHandles (remote

Review comment: I have a random request, do you mind staying with `workers`?

                RolloutWorkers) to add this policy to. If defined, will only
                add the given policy to these workers.

        Returns:
            The newly added policy (the copy that got added to the local
-           worker).
-       """
-       kwargs = dict(
-           policy_id=policy_id,
-           policy_cls=policy_cls,
-           observation_space=observation_space,
-           action_space=action_space,
-           config=config,
-           policy_state=policy_state,
-           policy_mapping_fn=policy_mapping_fn,
-           policies_to_train=list(policies_to_train) if policies_to_train else None,
-       )
-
-       def fn(worker: RolloutWorker):
-           # `foreach_worker` function: Adds the policy the the worker (and
-           # maybe changes its policy_mapping_fn - if provided here).
-           worker.add_policy(**kwargs)
+           worker). If `worker_list` was provided, None is returned.
+
+       Raises:
+           ValueError: If both `policy_cls` AND `policy` are provided.
+           KeyError: If the given `policy_id` already exists in this Algorithm.
+       """
+       # Deprecated args.
        if workers is not None:
-           ray_gets = []
-           for worker in workers:
-               if isinstance(worker, ActorHandle):
-                   ray_gets.append(worker.add_policy.remote(**kwargs))
-               else:
-                   fn(worker)
-           ray.get(ray_gets)
+           deprecation_warning(
+               old="Algorithm.add_policy(.., workers=...)",
+               new="Algorithm.add_policy(.., worker_list=...)",
+               error=False,
+           )
+           worker_list = workers
+
+       # Worker list is explicitly provided -> Use only those workers (local or remote)
+       # specified.
+       if worker_list is not None:
+           RolloutWorker.add_policy_to_workers(

Review comment: I like this change, this actually works well with my fault tolerance PR. I will note one problem though, this should be ...
Reply: Oh, yeah, great catch!

+               worker_list,
+               policy_id,
+               policy_cls,
+               policy,
+               observation_space=observation_space,
+               action_space=action_space,
+               config=config,
+               policy_state=policy_state,
+               policy_mapping_fn=policy_mapping_fn,
+               policies_to_train=policies_to_train,
+           )
+       # Add to all our regular RolloutWorkers and maybe also all evaluation workers.
        else:
-           # Run foreach_worker fn on all workers.
-           self.workers.foreach_worker(fn)
+           self.workers.add_policy(
+               policy_id,
+               policy_cls,
+               policy,
+               observation_space=observation_space,
+               action_space=action_space,
+               config=config,
+               policy_state=policy_state,
+               policy_mapping_fn=policy_mapping_fn,
+               policies_to_train=policies_to_train,
+           )

-       # Update evaluation workers, if necessary.
-       if evaluation_workers and self.evaluation_workers is not None:
-           self.evaluation_workers.foreach_worker(fn)
+       # Add to evaluation workers, if necessary.
+       if evaluation_workers and self.evaluation_workers is not None:
+           self.evaluation_workers.add_policy(
+               policy_id,
+               policy_cls,
+               policy,
+               observation_space=observation_space,
+               action_space=action_space,
+               config=config,
+               policy_state=policy_state,
+               policy_mapping_fn=policy_mapping_fn,
+               policies_to_train=policies_to_train,
+           )

        # Return newly added policy (from the local rollout worker).
        return self.get_policy(policy_id)
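The `workers=None` handling above is a common soft-deprecation pattern: keep accepting the old keyword, emit a warning, and forward its value to the new one. A generic, standalone sketch of that pattern, using Python's `warnings` module instead of RLlib's `deprecation_warning` helper (function and argument names here are illustrative):

```python
import warnings
from typing import Optional, Sequence

def add_policy(policy_id: str, *, worker_list: Optional[Sequence] = None, workers=None) -> None:
    # Soft-deprecation shim: keep the old `workers` kwarg working for now,
    # but steer callers toward `worker_list`.
    if workers is not None:
        warnings.warn(
            "`workers=...` is deprecated; use `worker_list=...` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        worker_list = workers
    target = worker_list if worker_list is not None else "all workers"
    print(f"adding {policy_id!r} to {target}")

add_policy("pol1", workers=["worker_0"])  # warns, then behaves like worker_list=["worker_0"]
```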
@@ -1646,7 +1677,7 @@ def remove_policy(
        ] = None,
        evaluation_workers: bool = True,
    ) -> None:
-       """Removes a new policy from this Trainer.
+       """Removes a new policy from this Algorithm.

        Args:
            policy_id: ID of the policy to be removed.
@@ -1693,12 +1724,12 @@ def export_policy_model(

        Example:
            >>> from ray.rllib.algorithms.ppo import PPO
-           >>> # Use a Trainer from RLlib or define your own.
-           >>> trainer = PPO(...) # doctest: +SKIP
+           >>> # Use an Algorithm from RLlib or define your own.
+           >>> algo = PPO(...) # doctest: +SKIP
            >>> for _ in range(10): # doctest: +SKIP
-           >>> trainer.train() # doctest: +SKIP
-           >>> trainer.export_policy_model("/tmp/dir") # doctest: +SKIP
-           >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP
+           >>> algo.train() # doctest: +SKIP
+           >>> algo.export_policy_model("/tmp/dir") # doctest: +SKIP
+           >>> algo.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP
        """
        self.get_policy(policy_id).export_model(export_dir, onnx)
@@ -1718,11 +1749,11 @@ def export_policy_checkpoint(

        Example:
            >>> from ray.rllib.algorithms.ppo import PPO
-           >>> # Use a Trainer from RLlib or define your own.
-           >>> trainer = PPO(...) # doctest: +SKIP
+           >>> # Use an Algorithm from RLlib or define your own.
+           >>> algo = PPO(...) # doctest: +SKIP
            >>> for _ in range(10): # doctest: +SKIP
-           >>> trainer.train() # doctest: +SKIP
-           >>> trainer.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
+           >>> algo.train() # doctest: +SKIP
+           >>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
        """
        self.get_policy(policy_id).export_checkpoint(export_dir, filename_prefix)
@@ -1740,10 +1771,10 @@ def import_policy_model_from_h5(

        Example:
            >>> from ray.rllib.algorithms.ppo import PPO
-           >>> trainer = PPO(...) # doctest: +SKIP
-           >>> trainer.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
+           >>> algo = PPO(...) # doctest: +SKIP
+           >>> algo.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP
            >>> for _ in range(10): # doctest: +SKIP
-           >>> trainer.train() # doctest: +SKIP
+           >>> algo.train() # doctest: +SKIP
        """
        self.get_policy(policy_id).import_model_from_h5(import_file)
        # Sync new weights to remote workers.
@@ -1787,7 +1818,7 @@ def default_resource_request(
        cls, config: PartialAlgorithmConfigDict
    ) -> Union[Resources, PlacementGroupFactory]:

-       # Default logic for RLlib algorithms (Trainers):
+       # Default logic for RLlib Algorithms:
        # Create one bundle per individual worker (local or remote).
        # Use `num_cpus_for_driver` and `num_gpus` for the local worker and
        # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote
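For readers unfamiliar with how that "one bundle per worker" comment plays out, here is a rough, hypothetical sketch of the bundle-building logic it describes. The config keys follow the comment above plus the standard `num_workers` key; this is not the actual RLlib implementation:

```python
# Hypothetical sketch of the "one bundle per worker" resource logic described
# in the comment above; not the actual RLlib implementation.
def build_bundles(config: dict) -> list:
    # Local worker (driver) bundle.
    bundles = [{
        "CPU": config.get("num_cpus_for_driver", 1),
        "GPU": config.get("num_gpus", 0),
    }]
    # One bundle per remote rollout worker.
    bundles += [
        {
            "CPU": config.get("num_cpus_per_worker", 1),
            "GPU": config.get("num_gpus_per_worker", 0),
        }
        for _ in range(config.get("num_workers", 0))
    ]
    return bundles

print(build_bundles({"num_workers": 2, "num_gpus": 1}))
```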
@@ -1854,7 +1885,7 @@ def _get_env_id_and_creator(
        Args:
            env_specifier: An env class, an already tune registered env ID, a known
                gym env name, or None (if no env is used).
-           config: The Trainer's (maybe partial) config dict.
+           config: The Algorithm's (maybe partial) config dict.

        Returns:
            Tuple consisting of a) env ID string and b) env creator callable.
@@ -1976,21 +2007,21 @@ def merge_trainer_configs(
        config2: PartialAlgorithmConfigDict,
        _allow_unknown_configs: Optional[bool] = None,
    ) -> AlgorithmConfigDict:
-       """Merges a complete Trainer config with a partial override dict.
+       """Merges a complete Algorithm config dict with a partial override dict.

        Respects nested structures within the config dicts. The values in the
        partial override dict take priority.

        Args:
-           config1: The complete Trainer's dict to be merged (overridden)
+           config1: The complete Algorithm's dict to be merged (overridden)
                with `config2`.
            config2: The partial override config dict to merge on top of
                `config1`.
            _allow_unknown_configs: If True, keys in `config2` that don't exist
                in `config1` are allowed and will be added to the final config.

        Returns:
-           The merged full trainer config dict.
+           The merged full algorithm config dict.
        """
        config1 = copy.deepcopy(config1)
        if "callbacks" in config2 and type(config2["callbacks"]) is dict:
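The nested-merge semantics that docstring describes boil down to a recursive dict merge where override values win. A minimal, generic sketch of that behavior (not RLlib's actual merge helper; the example config keys are illustrative):

```python
import copy

def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base`; override values win."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Example: only the overridden leaf changes; sibling keys survive.
full = {"lr": 5e-5, "exploration_config": {"type": "EpsilonGreedy", "epsilon_timesteps": 10000}}
partial = {"exploration_config": {"epsilon_timesteps": 50000}}
merged = deep_merge(full, partial)
assert merged["exploration_config"]["type"] == "EpsilonGreedy"
assert merged["exploration_config"]["epsilon_timesteps"] == 50000
```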
@@ -2092,7 +2123,7 @@ def resolve_tf_settings():
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @DeveloperAPI
    def validate_config(self, config: AlgorithmConfigDict) -> None:
-       """Validates a given config dict for this Trainer.
+       """Validates a given config dict for this Algorithm.

        Users should override this method to implement custom validation
        behavior. It is recommended to call `super().validate_config()` in
@@ -2310,8 +2341,8 @@ def validate_config(self, config: AlgorithmConfigDict) -> None:
                f"You have specified {config['evaluation_num_workers']} "
                "evaluation workers, but your `evaluation_interval` is None! "
                "Therefore, evaluation will not occur automatically with each"
-               " call to `Trainer.train()`. Instead, you will have to call "
-               "`Trainer.evaluate()` manually in order to trigger an "
+               " call to `Algorithm.train()`. Instead, you will have to call "
+               "`Algorithm.evaluate()` manually in order to trigger an "
                "evaluation run."
            )
        # If `evaluation_num_workers=0` and
@@ -2349,7 +2380,7 @@ def validate_config(self, config: AlgorithmConfigDict) -> None:
    @staticmethod
    @ExperimentalAPI
    def validate_env(env: EnvType, env_context: EnvContext) -> None:
-       """Env validator function for this Trainer class.
+       """Env validator function for this Algorithm class.

        Override this in child classes to define custom validation
        behavior.
@@ -2547,7 +2578,7 @@ def _create_local_replay_buffer_if_necessary(
            config: Algorithm-specific configuration data.

        Returns:
-           MultiAgentReplayBuffer instance based on trainer config.
+           MultiAgentReplayBuffer instance based on algorithm config.
            None, if local replay buffer is not needed.
        """
        if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
@@ -2842,7 +2873,7 @@ def _record_usage(self, config):
            alg = "USER_DEFINED"
        record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)

-   @Deprecated(new="Trainer.compute_single_action()", error=False)
+   @Deprecated(new="Algorithm.compute_single_action()", error=False)
    def compute_action(self, *args, **kwargs):
        return self.compute_single_action(*args, **kwargs)
@@ -2872,7 +2903,7 @@ def _make_workers(
        )

    @staticmethod
-   @Deprecated(new="Trainer.validate_config()", error=False)
+   @Deprecated(new="Algorithm.validate_config()", error=False)
    def _validate_config(config, trainer_or_none):
        assert trainer_or_none is not None
        return trainer_or_none.validate_config(config)
Review comment: Found a few of these old `Trainer` references in the comments.