From f0dfda965d262b332da8dc0b389e87d6fd6ff1ed Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 20 Sep 2022 14:59:45 +0200 Subject: [PATCH 1/8] wip Signed-off-by: sven1977 --- rllib/algorithms/algorithm.py | 93 +++++++++++++++++--- rllib/algorithms/tests/test_algorithm.py | 104 ++++++++++++++--------- rllib/connectors/util.py | 6 +- rllib/evaluation/rollout_worker.py | 80 ++++++++++++----- rllib/policy/policy_map.py | 35 ++++---- 5 files changed, 226 insertions(+), 92 deletions(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 3217584deea9..b9991863b5ad 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -1547,7 +1547,8 @@ def set_weights(self, weights: Dict[PolicyID, dict]): def add_policy( self, policy_id: PolicyID, - policy_cls: Type[Policy], + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, *, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, @@ -1569,6 +1570,9 @@ def add_policy( policy_id: ID of the policy to add. policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this algorithm. + Note: Only one of `policy_cls` or `policy` must be provided. observation_space: The observation space of the policy to add. If None, try to infer this space from the environment. action_space: The action space of the policy to add. @@ -1595,35 +1599,98 @@ def add_policy( Returns: The newly added policy (the copy that got added to the local worker). + + Raises: + ValueError: If both `policy_cls` AND `policy` are provided. + KeyError: If the given `policy_id` already exists in this Algorithm. """ + local_worker = self.workers.local_worker() - kwargs = dict( - policy_id=policy_id, - policy_cls=policy_cls, - observation_space=observation_space, - action_space=action_space, - config=config, - policy_state=policy_state, - policy_mapping_fn=policy_mapping_fn, - policies_to_train=list(policies_to_train) if policies_to_train else None, - ) + if policy_id in local_worker.policy_map: + raise KeyError( + f"Policy ID '{policy_id}' already exists in policy map! " + "Make sure you use a Policy ID that has not been taken yet." + " Policy IDs that are already in your policy map: " + f"{list(local_worker.policy_map.keys())}" + ) + + if policy_cls is not None and policy is not None: + raise ValueError( + "Only one of `policy_cls` or `policy` must be provided to " + "Algorithm.add_policy()!" + ) + + # Policy instance not provided: Use the information given here. + if policy_cls is not None: + kwargs = dict( + policy_id=policy_id, + policy_cls=policy_cls, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=list(policies_to_train) + if policies_to_train + else None, + ) + # Policy instance provided: Create clones of this very policy on the different + # workers (copy all its properties here for the calls to add_policy on the + # remote workers). 
+ else: + kwargs = dict( + policy_id=policy_id, + policy_cls=type(policy), + observation_space=policy.observation_space, + action_space=policy.action_space, + config=policy.config, + policy_state=policy.get_state(), + policy_mapping_fn=policy_mapping_fn, + policies_to_train=list(policies_to_train) + if policies_to_train + else None, + ) def fn(worker: RolloutWorker): # `foreach_worker` function: Adds the policy the the worker (and # maybe changes its policy_mapping_fn - if provided here). worker.add_policy(**kwargs) + # Workers to add the policy to are given as an explicit list. if workers is not None: ray_gets = [] for worker in workers: + # A remote worker (ray actor). if isinstance(worker, ActorHandle): ray_gets.append(worker.add_policy.remote(**kwargs)) + # (Local) RolloutWorker instance. else: - fn(worker) + if policy is not None: + worker.add_policy( + policy_id=policy_id, + policy=policy, + policies_to_train=policies_to_train, + policy_mapping_fn=policy_mapping_fn, + ) + else: + fn(worker) ray.get(ray_gets) + # Add to all RolloutWorkers within `self.workers`. else: + # Policy is provided as an instance -> Add this very instance to local + # worker. + if policy is not None: + local_worker.add_policy( + policy_id=policy_id, + policy=policy, + policies_to_train=policies_to_train, + policy_mapping_fn=policy_mapping_fn, + ) + # Then add a new instance to each remote worker. + ray.get([w.apply.remote(fn) for w in self.workers.remote_workers()]) # Run foreach_worker fn on all workers. - self.workers.foreach_worker(fn) + else: + self.workers.foreach_worker(fn) # Update evaluation workers, if necessary. if evaluation_workers and self.evaluation_workers is not None: diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index ae45389effe3..231a5bac4f33 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -52,40 +52,54 @@ def test_validate_config_idempotent(self): algo.stop() def test_add_delete_policy(self): - config = pg.DEFAULT_CONFIG.copy() - config.update( - { - "env": MultiAgentCartPole, - "env_config": { - "config": { - "num_agents": 4, - }, + config = pg.PGConfig() + config.environment( + env=MultiAgentCartPole, + env_config={ + "config": { + "num_agents": 4, }, - "num_workers": 2, # Test on remote workers as well. + }, + ).rollouts(num_rollout_workers=2, rollout_fragment_length=50).resources( + num_cpus_per_worker=0.1 + ).training( + train_batch_size=100, + ).multi_agent( + # Start with a single policy. + policies={"p0"}, + policy_mapping_fn=lambda aid, eps, worker, **kwargs: "p0", + # And only two policies that can be stored in memory at a + # time. + policy_map_capacity=2, + ).evaluation( + evaluation_num_workers=1, + evaluation_config={ "num_cpus_per_worker": 0.1, - "model": { - "fcnet_hiddens": [5], - "fcnet_activation": "linear", - }, - "train_batch_size": 100, - "rollout_fragment_length": 50, - "multiagent": { - # Start with a single policy. - "policies": {"p0"}, - "policy_mapping_fn": lambda aid, eps, worker, **kwargs: "p0", - # And only two policies that can be stored in memory at a - # time. - "policy_map_capacity": 2, - }, - "evaluation_num_workers": 1, - "evaluation_config": { - "num_cpus_per_worker": 0.1, - }, + }, + ) + # Don't override existing model settings. 
+ config.model.update( + { + "fcnet_hiddens": [5], + "fcnet_activation": "linear", } ) - for _ in framework_iterator(config): - algo = pg.PG(config=config) + obs_space = gym.spaces.Box(-2.0, 2.0, (4,)) + act_space = gym.spaces.Discrete(2) + + for fw in framework_iterator(config): + # Pre-generate a policy instance to test adding these directly to an + # existing algorithm. + if fw == "tf": + policy_obj = pg.PGTF1Policy(obs_space, act_space, config.to_dict()) + elif fw == "tf2": + policy_obj = pg.PGTF2Policy(obs_space, act_space, config.to_dict()) + else: + policy_obj = pg.PGTorchPolicy(obs_space, act_space, config.to_dict()) + + # Construct the Algorithm with a single policy in it. + algo = config.build() pol0 = algo.get_policy("p0") r = algo.train() self.assertTrue("p0" in r["info"][LEARNER_INFO]) @@ -94,16 +108,30 @@ def test_add_delete_policy(self): def new_mapping_fn(agent_id, episode, worker, **kwargs): return f"p{choice([i, i - 1])}" - # Add a new policy. + # Add a new policy either by class (and options) or by instance. pid = f"p{i}" - new_pol = algo.add_policy( - pid, - algo.get_default_policy_class(config), - # Test changing the mapping fn. - policy_mapping_fn=new_mapping_fn, - # Change the list of policies to train. - policies_to_train=[f"p{i}", f"p{i-1}"], - ) + print(f"Adding policy {pid} ...") + # By instance. + if i == 2: + new_pol = algo.add_policy( + pid, + # Pass in an already existing policy instance. + policy=policy_obj, + # Test changing the mapping fn. + policy_mapping_fn=new_mapping_fn, + # Change the list of policies to train. + policies_to_train=[f"p{i}", f"p{i - 1}"], + ) + # By class (and options). + else: + new_pol = algo.add_policy( + pid, + algo.get_default_policy_class(config.to_dict()), + # Test changing the mapping fn. + policy_mapping_fn=new_mapping_fn, + # Change the list of policies to train. + policies_to_train=[f"p{i}", f"p{i-1}"], + ) # Make sure new policy is part of remote workers in the # worker set and the eval worker set. assert pid in ( diff --git a/rllib/connectors/util.py b/rllib/connectors/util.py index 3cbe92addc00..aec37d8bcd2e 100644 --- a/rllib/connectors/util.py +++ b/rllib/connectors/util.py @@ -78,8 +78,10 @@ def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict): """ ctx: ConnectorContext = ConnectorContext.from_policy(policy) - policy.agent_connectors = get_agent_connectors_from_config(ctx, config) - policy.action_connectors = get_action_connectors_from_config(ctx, config) + if policy.agent_connectors is not None: + policy.agent_connectors = get_agent_connectors_from_config(ctx, config) + if policy.action_connectors is not None: + policy.action_connectors = get_action_connectors_from_config(ctx, config) logger.info("Using connectors:") logger.info(policy.agent_connectors.__str__(indentation=4)) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 550cefefc0c2..2c475395937a 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -1179,9 +1179,10 @@ def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Optional[Policy @DeveloperAPI def add_policy( self, - *, policy_id: PolicyID, - policy_cls: Type[Policy], + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, + *, observation_space: Optional[Space] = None, action_space: Optional[Space] = None, config: Optional[PartialAlgorithmConfigDict] = None, @@ -1195,8 +1196,10 @@ def add_policy( Args: policy_id: ID of the policy to add. 
- policy_cls: The Policy class to use for constructing the new - Policy. + policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this algorithm. + Note: Only one of `policy_cls` or `policy` must be provided. observation_space: The observation space of the policy to add. action_space: The action space of the policy to add. config: The config overrides for the policy to add. @@ -1217,24 +1220,50 @@ def add_policy( The newly added policy. Raises: + ValueError: If both `policy_cls` AND `policy` are provided. KeyError: If the given `policy_id` already exists in this worker's PolicyMap. """ if policy_id in self.policy_map: - raise KeyError(f"Policy ID '{policy_id}' already in policy map!") - policy_dict_to_add = _determine_spaces_for_multi_agent_dict( - { + raise KeyError( + f"Policy ID '{policy_id}' already exists in policy map! " + "Make sure you use a Policy ID that has not been taken yet." + " Policy IDs that are already in your policy map: " + f"{list(self.workers.local_worker().policy_map.keys())}" + ) + if policy_cls is not None and policy is not None: + raise ValueError( + "Only one of `policy_cls` or `policy` must be provided to " + "RolloutWorker.add_policy()!" + ) + + if policy is None: + policy_dict_to_add = _determine_spaces_for_multi_agent_dict( + { + policy_id: PolicySpec( + policy_cls, observation_space, action_space, config or {} + ) + }, + self.env, + spaces=self.spaces, + policy_config=self.policy_config, + ) + else: + policy_dict_to_add = { policy_id: PolicySpec( - policy_cls, observation_space, action_space, config or {} + type(policy), + policy.observation_space, + policy.action_space, + policy.config, ) - }, - self.env, - spaces=self.spaces, - policy_config=self.policy_config, - ) + } + self.policy_dict.update(policy_dict_to_add) self._build_policy_map( - policy_dict_to_add, self.policy_config, seed=self.policy_config.get("seed") + policy_dict=policy_dict_to_add, + policy_config=self.policy_config, + policy=policy, + seed=self.policy_config.get("seed"), ) new_policy = self.policy_map[policy_id] # Set the state of the newly created policy. @@ -1752,6 +1781,7 @@ def _build_policy_map( self, policy_dict: MultiAgentPolicyConfigDict, policy_config: PartialAlgorithmConfigDict, + policy: Optional[Policy] = None, session_creator: Optional[Callable[[], "tf1.Session"]] = None, seed: Optional[int] = None, ) -> None: @@ -1763,6 +1793,7 @@ def _build_policy_map( policy_config: The general policy config to use. May be updated by individual policy config overrides in the given multi-agent `policy_dict`. + policy: If the policy to add already exists, user can provide it here. session_creator: A callable that creates a tf session (if applicable). seed: An optional random seed to pass to PolicyMap's @@ -1816,15 +1847,18 @@ def _build_policy_map( # the running of these preprocessors. self.preprocessors[name] = preprocessor - # Create the actual policy object. - self.policy_map.create_policy( - name, - policy_spec.policy_class, - obs_space, - policy_spec.action_space, - policy_spec.config, # overrides. - merged_conf, - ) + if policy is not None: + self.policy_map.insert_policy(name, policy) + else: + # Create the actual policy object. + self.policy_map.create_policy( + name, + policy_spec.policy_class, + obs_space, + policy_spec.action_space, + policy_spec.config, # overrides. 
+ merged_conf, + ) if connectors_enabled and name in self.policy_map: create_connectors_for_policy(self.policy_map[name], policy_config) diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py index 85555a0a63b9..6e109ff3dbd6 100644 --- a/rllib/policy/policy_map.py +++ b/rllib/policy/policy_map.py @@ -2,10 +2,10 @@ import gym import os import threading -from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Type +from typing import Callable, Dict, Optional, Set, Type import ray.cloudpickle as pickle -from ray.rllib.policy.policy import PolicySpec +from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.utils.annotations import PublicAPI, override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.policy import create_policy_for_framework @@ -18,9 +18,6 @@ ) from ray.tune.utils.util import merge_dicts -if TYPE_CHECKING: - from ray.rllib.policy.policy import Policy - tf1, tf, tfv = try_import_tf() @@ -91,6 +88,21 @@ def __init__( # and the underlying structures, like self.deque and others. self._lock = threading.RLock() + def insert_policy( + self, policy_id: PolicyID, policy: Policy, config_override=None + ) -> None: + self[policy_id] = policy + + # Store spec (class, obs-space, act-space, and config overrides) such + # that the map will be able to reproduce on-the-fly added policies + # from disk. + self.policy_specs[policy_id] = PolicySpec( + policy_class=type(policy), + observation_space=policy.observation_space, + action_space=policy.action_space, + config=config_override if config_override is not None else policy.config, + ) + def create_policy( self, policy_id: PolicyID, @@ -119,7 +131,7 @@ def create_policy( """ _class = get_tf_eager_cls_if_necessary(policy_cls, merged_config) - self[policy_id] = create_policy_for_framework( + policy = create_policy_for_framework( policy_id, _class, merged_config, @@ -129,16 +141,7 @@ def create_policy( self.session_creator, self.seed, ) - - # Store spec (class, obs-space, act-space, and config overrides) such - # that the map will be able to reproduce on-the-fly added policies - # from disk. 
- self.policy_specs[policy_id] = PolicySpec( - policy_class=policy_cls, - observation_space=observation_space, - action_space=action_space, - config=config_override, - ) + self.insert_policy(policy_id, policy, config_override) @with_lock @override(dict) From 8fde067f02fa45c2a129ec9114d6e5c254208fce Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 20 Sep 2022 19:14:45 +0200 Subject: [PATCH 2/8] merge Signed-off-by: sven1977 --- release/ray_release/cluster_manager/cluster_manager.py | 1 + release/setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/release/ray_release/cluster_manager/cluster_manager.py b/release/ray_release/cluster_manager/cluster_manager.py index d1ad14e1d080..69606084fd52 100644 --- a/release/ray_release/cluster_manager/cluster_manager.py +++ b/release/ray_release/cluster_manager/cluster_manager.py @@ -52,6 +52,7 @@ def set_cluster_env(self, cluster_env: Dict[str, Any]): self.cluster_env["env_vars"]["RAY_USAGE_STATS_ENABLED"] = "1" self.cluster_env["env_vars"]["RAY_USAGE_STATS_SOURCE"] = "nightly-tests" self.cluster_env["env_vars"]["RAY_memory_monitor_interval_ms"] = "250" + self.cluster_env["env_vars"]["RAY_memory_usage_threshold_fraction"] = "0.95" self.cluster_env["env_vars"][ "RAY_USAGE_STATS_EXTRA_TAGS" ] = f"test_name={self.test_name};smoke_test={self.smoke_test}" diff --git a/release/setup.py b/release/setup.py index a10ed0fa7a14..9c5ded69673e 100644 --- a/release/setup.py +++ b/release/setup.py @@ -12,5 +12,5 @@ author="Ray Team", description="The Ray OSS release testing package", url="https://github.com/ray-project/ray", - install_requires=["ray>=1.9", "click", "anyscale", "boto3", "freezegun"], + install_requires=["ray>=1.9", "click", "anyscale", "boto3", "freezegun", "retry"], ) From c8afcf392e9b0051a4bb59d97e803e83a0cb6cf7 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 20 Sep 2022 21:14:25 +0200 Subject: [PATCH 3/8] wip Signed-off-by: sven1977 --- rllib/connectors/util.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rllib/connectors/util.py b/rllib/connectors/util.py index 4074593de7a8..7ccb4ce18956 100644 --- a/rllib/connectors/util.py +++ b/rllib/connectors/util.py @@ -78,14 +78,14 @@ def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict): """ ctx: ConnectorContext = ConnectorContext.from_policy(policy) - if policy.agent_connectors is not None: + if policy.agent_connectors is None: policy.agent_connectors = get_agent_connectors_from_config(ctx, config) - logger.info("Using agent connector:") - logger.info(policy.agent_connectors.__str__(indentation=4)) - if policy.action_connectors is not None: + if policy.action_connectors is None: policy.action_connectors = get_action_connectors_from_config(ctx, config) - logger.info("Using action connector:") - logger.info(policy.action_connectors.__str__(indentation=4)) + + logger.info("Using connectors:") + logger.info(policy.agent_connectors.__str__(indentation=4)) + logger.info(policy.action_connectors.__str__(indentation=4)) @PublicAPI(stability="alpha") From f230ffc63ffe624abed61e4b95ac7eb855479898 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 21 Sep 2022 14:44:59 +0200 Subject: [PATCH 4/8] wip Signed-off-by: sven1977 --- rllib/algorithms/algorithm.py | 47 +++++++++++++++++------------- rllib/evaluation/rollout_worker.py | 4 +-- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index b9991863b5ad..7e5361bf21e5 100644 --- a/rllib/algorithms/algorithm.py 
+++ b/rllib/algorithms/algorithm.py @@ -1614,7 +1614,7 @@ def add_policy( f"{list(local_worker.policy_map.keys())}" ) - if policy_cls is not None and policy is not None: + if (policy_cls is None) == (policy is None): raise ValueError( "Only one of `policy_cls` or `policy` must be provided to " "Algorithm.add_policy()!" @@ -1622,7 +1622,7 @@ def add_policy( # Policy instance not provided: Use the information given here. if policy_cls is not None: - kwargs = dict( + new_policy_instance_kwargs = dict( policy_id=policy_id, policy_cls=policy_cls, observation_space=observation_space, @@ -1638,7 +1638,7 @@ def add_policy( # workers (copy all its properties here for the calls to add_policy on the # remote workers). else: - kwargs = dict( + new_policy_instance_kwargs = dict( policy_id=policy_id, policy_cls=type(policy), observation_space=policy.observation_space, @@ -1651,29 +1651,30 @@ def add_policy( else None, ) - def fn(worker: RolloutWorker): + def _create_new_policy_fn(worker: RolloutWorker): # `foreach_worker` function: Adds the policy the the worker (and # maybe changes its policy_mapping_fn - if provided here). - worker.add_policy(**kwargs) + worker.add_policy(**new_policy_instance_kwargs) # Workers to add the policy to are given as an explicit list. if workers is not None: ray_gets = [] for worker in workers: + # Existing policy AND local worker: Add Policy as-is. + if policy is not None and not isinstance(worker, ActorHandle): + worker.add_policy( + policy_id=policy_id, + policy=policy, + policies_to_train=policies_to_train, + policy_mapping_fn=policy_mapping_fn, + ) # A remote worker (ray actor). - if isinstance(worker, ActorHandle): - ray_gets.append(worker.add_policy.remote(**kwargs)) + elif isinstance(worker, ActorHandle): + ray_gets.append(worker.apply.remote(_create_new_policy_fn)) # (Local) RolloutWorker instance. else: - if policy is not None: - worker.add_policy( - policy_id=policy_id, - policy=policy, - policies_to_train=policies_to_train, - policy_mapping_fn=policy_mapping_fn, - ) - else: - fn(worker) + worker.add_policy(**new_policy_instance_kwargs) + ray.get(ray_gets) # Add to all RolloutWorkers within `self.workers`. else: @@ -1687,14 +1688,20 @@ def fn(worker: RolloutWorker): policy_mapping_fn=policy_mapping_fn, ) # Then add a new instance to each remote worker. - ray.get([w.apply.remote(fn) for w in self.workers.remote_workers()]) - # Run foreach_worker fn on all workers. + ray.get( + [ + w.apply.remote(_create_new_policy_fn) + for w in self.workers.remote_workers() + ] + ) + # Run foreach_worker fn on all workers (incl. local one) + # to create new Policies on each. else: - self.workers.foreach_worker(fn) + self.workers.foreach_worker(_create_new_policy_fn) # Update evaluation workers, if necessary. if evaluation_workers and self.evaluation_workers is not None: - self.evaluation_workers.foreach_worker(fn) + self.evaluation_workers.foreach_worker(_create_new_policy_fn) # Return newly added policy (from the local rollout worker). return self.get_policy(policy_id) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 2c475395937a..db83ce03aa61 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -1229,9 +1229,9 @@ def add_policy( f"Policy ID '{policy_id}' already exists in policy map! " "Make sure you use a Policy ID that has not been taken yet." 
" Policy IDs that are already in your policy map: " - f"{list(self.workers.local_worker().policy_map.keys())}" + f"{list(self.policy_map.keys())}" ) - if policy_cls is not None and policy is not None: + if (policy_cls is None) == (policy is None): raise ValueError( "Only one of `policy_cls` or `policy` must be provided to " "RolloutWorker.add_policy()!" From a01ccb2dc4af9537e87efc1df9ee16daf8f52bcb Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 23 Sep 2022 09:56:14 +0200 Subject: [PATCH 5/8] wip Signed-off-by: sven1977 --- rllib/algorithms/algorithm.py | 199 ++++++++++++------------------ rllib/evaluation/worker_set.py | 217 ++++++++++++++++++++++++++++++++- 2 files changed, 292 insertions(+), 124 deletions(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 7e5361bf21e5..46fc16137a4f 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -231,7 +231,7 @@ class directly. Note that this arg can also be specified via """ # User provided (partial) config (this may be w/o the default - # Trainer's Config object). Will get merged with AlgorithmConfig() + # Algorithm's Config object). Will get merged with AlgorithmConfig() # in self.setup(). config = config or {} # Resolve AlgorithmConfig into a plain dict. @@ -343,7 +343,7 @@ def setup(self, config: PartialAlgorithmConfigDict): # Validate the framework settings in config. self.validate_framework(self.config) - # Set Trainer's seed after we have - if necessary - enabled + # Set Algorithm's seed after we have - if necessary - enabled # tf eager-execution. update_global_seed_if_necessary(self.config["framework"], self.config["seed"]) @@ -368,7 +368,7 @@ def setup(self, config: PartialAlgorithmConfigDict): # Create a dict, mapping ActorHandles to sets of open remote # requests (object refs). This way, we keep track, of which actors - # inside this Trainer (e.g. a remote RolloutWorker) have + # inside this Algorithm (e.g. a remote RolloutWorker) have # already been sent how many (e.g. `sample()`) requests. self.remote_requests_in_flight: DefaultDict[ ActorHandle, Set[ray.ObjectRef] @@ -1562,16 +1562,20 @@ def add_policy( ] ] = None, evaluation_workers: bool = True, - workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, - ) -> Policy: - """Adds a new policy to this Trainer. + worker_list: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, + # Deprecated args: + workers=None, + ) -> Optional[Policy]: + """Adds a new policy to this Algorithm. Args: policy_id: ID of the policy to add. - policy_cls: The Policy class to use for - constructing the new Policy. + policy_cls: The Policy class to use for constructing the new Policy. Note: Only one of `policy_cls` or `policy` must be provided. - policy: The Policy instance to add to this algorithm. + policy: The Policy instance to add to this algorithm. If not None, the + given Policy object will be directly inserted into the Algorithm's + local worker and clones of that Policy will be created on all remote + workers as well as all evaluation workers. Note: Only one of `policy_cls` or `policy` must be provided. observation_space: The observation space of the policy to add. If None, try to infer this space from the environment. @@ -1592,116 +1596,69 @@ def add_policy( returns False) will not be updated. evaluation_workers: Whether to add the new policy also to the evaluation WorkerSet. 
- workers: A list of RolloutWorker/ActorHandles (remote + worker_list: A list of RolloutWorker/ActorHandles (remote RolloutWorkers) to add this policy to. If defined, will only add the given policy to these workers. Returns: The newly added policy (the copy that got added to the local - worker). + worker). If `worker_list` was provided, None is returned. Raises: ValueError: If both `policy_cls` AND `policy` are provided. KeyError: If the given `policy_id` already exists in this Algorithm. """ - local_worker = self.workers.local_worker() - - if policy_id in local_worker.policy_map: - raise KeyError( - f"Policy ID '{policy_id}' already exists in policy map! " - "Make sure you use a Policy ID that has not been taken yet." - " Policy IDs that are already in your policy map: " - f"{list(local_worker.policy_map.keys())}" - ) - - if (policy_cls is None) == (policy is None): - raise ValueError( - "Only one of `policy_cls` or `policy` must be provided to " - "Algorithm.add_policy()!" + # Deprecated args. + if workers is not None: + deprecation_warning( + old="Algorithm.add_policy(.., workers=...)", + new="Algorithm.add_policy(.., worker_list=...)", + error=False, ) - - # Policy instance not provided: Use the information given here. - if policy_cls is not None: - new_policy_instance_kwargs = dict( - policy_id=policy_id, - policy_cls=policy_cls, + worker_list = workers + + # Worker list is explicitly provided -> Use only those workers (local or remote) + # specified. + if worker_list is not None: + RolloutWorker.add_policy_to_workers( + worker_list, + policy_id, + policy_cls, + policy, observation_space=observation_space, action_space=action_space, config=config, policy_state=policy_state, policy_mapping_fn=policy_mapping_fn, - policies_to_train=list(policies_to_train) - if policies_to_train - else None, + policies_to_train=policies_to_train, ) - # Policy instance provided: Create clones of this very policy on the different - # workers (copy all its properties here for the calls to add_policy on the - # remote workers). + # Add to all our regular RolloutWorkers and maybe also all evaluation workers. else: - new_policy_instance_kwargs = dict( - policy_id=policy_id, - policy_cls=type(policy), - observation_space=policy.observation_space, - action_space=policy.action_space, - config=policy.config, - policy_state=policy.get_state(), + self.workers.add_policy( + policy_id, + policy_cls, + policy, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, policy_mapping_fn=policy_mapping_fn, - policies_to_train=list(policies_to_train) - if policies_to_train - else None, + policies_to_train=policies_to_train, ) - def _create_new_policy_fn(worker: RolloutWorker): - # `foreach_worker` function: Adds the policy the the worker (and - # maybe changes its policy_mapping_fn - if provided here). - worker.add_policy(**new_policy_instance_kwargs) - - # Workers to add the policy to are given as an explicit list. - if workers is not None: - ray_gets = [] - for worker in workers: - # Existing policy AND local worker: Add Policy as-is. - if policy is not None and not isinstance(worker, ActorHandle): - worker.add_policy( - policy_id=policy_id, - policy=policy, - policies_to_train=policies_to_train, - policy_mapping_fn=policy_mapping_fn, - ) - # A remote worker (ray actor). - elif isinstance(worker, ActorHandle): - ray_gets.append(worker.apply.remote(_create_new_policy_fn)) - # (Local) RolloutWorker instance. 
- else: - worker.add_policy(**new_policy_instance_kwargs) - - ray.get(ray_gets) - # Add to all RolloutWorkers within `self.workers`. - else: - # Policy is provided as an instance -> Add this very instance to local - # worker. - if policy is not None: - local_worker.add_policy( - policy_id=policy_id, - policy=policy, - policies_to_train=policies_to_train, + # Add to evaluation workers, if necessary. + if evaluation_workers and self.evaluation_workers is not None: + self.evaluation_workers.add_policy( + policy_id, + policy_cls, + policy, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, ) - # Then add a new instance to each remote worker. - ray.get( - [ - w.apply.remote(_create_new_policy_fn) - for w in self.workers.remote_workers() - ] - ) - # Run foreach_worker fn on all workers (incl. local one) - # to create new Policies on each. - else: - self.workers.foreach_worker(_create_new_policy_fn) - - # Update evaluation workers, if necessary. - if evaluation_workers and self.evaluation_workers is not None: - self.evaluation_workers.foreach_worker(_create_new_policy_fn) # Return newly added policy (from the local rollout worker). return self.get_policy(policy_id) @@ -1720,7 +1677,7 @@ def remove_policy( ] = None, evaluation_workers: bool = True, ) -> None: - """Removes a new policy from this Trainer. + """Removes a new policy from this Algorithm. Args: policy_id: ID of the policy to be removed. @@ -1767,12 +1724,12 @@ def export_policy_model( Example: >>> from ray.rllib.algorithms.ppo import PPO - >>> # Use a Trainer from RLlib or define your own. - >>> trainer = PPO(...) # doctest: +SKIP + >>> # Use an Algorithm from RLlib or define your own. + >>> algo = PPO(...) # doctest: +SKIP >>> for _ in range(10): # doctest: +SKIP - >>> trainer.train() # doctest: +SKIP - >>> trainer.export_policy_model("/tmp/dir") # doctest: +SKIP - >>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP + >>> algo.train() # doctest: +SKIP + >>> algo.export_policy_model("/tmp/dir") # doctest: +SKIP + >>> algo.export_policy_model("/tmp/dir/onnx", onnx=1) # doctest: +SKIP """ self.get_policy(policy_id).export_model(export_dir, onnx) @@ -1792,11 +1749,11 @@ def export_policy_checkpoint( Example: >>> from ray.rllib.algorithms.ppo import PPO - >>> # Use a Trainer from RLlib or define your own. - >>> trainer = PPO(...) # doctest: +SKIP + >>> # Use an Algorithm from RLlib or define your own. + >>> algo = PPO(...) # doctest: +SKIP >>> for _ in range(10): # doctest: +SKIP - >>> trainer.train() # doctest: +SKIP - >>> trainer.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP + >>> algo.train() # doctest: +SKIP + >>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP """ self.get_policy(policy_id).export_checkpoint(export_dir, filename_prefix) @@ -1814,10 +1771,10 @@ def import_policy_model_from_h5( Example: >>> from ray.rllib.algorithms.ppo import PPO - >>> trainer = PPO(...) # doctest: +SKIP - >>> trainer.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP + >>> algo = PPO(...) # doctest: +SKIP + >>> algo.import_policy_model_from_h5("/tmp/weights.h5") # doctest: +SKIP >>> for _ in range(10): # doctest: +SKIP - >>> trainer.train() # doctest: +SKIP + >>> algo.train() # doctest: +SKIP """ self.get_policy(policy_id).import_model_from_h5(import_file) # Sync new weights to remote workers. 
@@ -1861,7 +1818,7 @@ def default_resource_request( cls, config: PartialAlgorithmConfigDict ) -> Union[Resources, PlacementGroupFactory]: - # Default logic for RLlib algorithms (Trainers): + # Default logic for RLlib Algorithms: # Create one bundle per individual worker (local or remote). # Use `num_cpus_for_driver` and `num_gpus` for the local worker and # `num_cpus_per_worker` and `num_gpus_per_worker` for the remote @@ -1928,7 +1885,7 @@ def _get_env_id_and_creator( Args: env_specifier: An env class, an already tune registered env ID, a known gym env name, or None (if no env is used). - config: The Trainer's (maybe partial) config dict. + config: The Algorithm's (maybe partial) config dict. Returns: Tuple consisting of a) env ID string and b) env creator callable. @@ -2050,13 +2007,13 @@ def merge_trainer_configs( config2: PartialAlgorithmConfigDict, _allow_unknown_configs: Optional[bool] = None, ) -> AlgorithmConfigDict: - """Merges a complete Trainer config with a partial override dict. + """Merges a complete Algorithm config dict with a partial override dict. Respects nested structures within the config dicts. The values in the partial override dict take priority. Args: - config1: The complete Trainer's dict to be merged (overridden) + config1: The complete Algorithm's dict to be merged (overridden) with `config2`. config2: The partial override config dict to merge on top of `config1`. @@ -2064,7 +2021,7 @@ def merge_trainer_configs( in `config1` are allowed and will be added to the final config. Returns: - The merged full trainer config dict. + The merged full algorithm config dict. """ config1 = copy.deepcopy(config1) if "callbacks" in config2 and type(config2["callbacks"]) is dict: @@ -2166,7 +2123,7 @@ def resolve_tf_settings(): @OverrideToImplementCustomLogic_CallToSuperRecommended @DeveloperAPI def validate_config(self, config: AlgorithmConfigDict) -> None: - """Validates a given config dict for this Trainer. + """Validates a given config dict for this Algorithm. Users should override this method to implement custom validation behavior. It is recommended to call `super().validate_config()` in @@ -2384,8 +2341,8 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: f"You have specified {config['evaluation_num_workers']} " "evaluation workers, but your `evaluation_interval` is None! " "Therefore, evaluation will not occur automatically with each" - " call to `Trainer.train()`. Instead, you will have to call " - "`Trainer.evaluate()` manually in order to trigger an " + " call to `Algorithm.train()`. Instead, you will have to call " + "`Algorithm.evaluate()` manually in order to trigger an " "evaluation run." ) # If `evaluation_num_workers=0` and @@ -2423,7 +2380,7 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: @staticmethod @ExperimentalAPI def validate_env(env: EnvType, env_context: EnvContext) -> None: - """Env validator function for this Trainer class. + """Env validator function for this Algorithm class. Override this in child classes to define custom validation behavior. @@ -2621,7 +2578,7 @@ def _create_local_replay_buffer_if_necessary( config: Algorithm-specific configuration data. Returns: - MultiAgentReplayBuffer instance based on trainer config. + MultiAgentReplayBuffer instance based on algorithm config. None, if local replay buffer is not needed. 
""" if not config.get("replay_buffer_config") or config["replay_buffer_config"].get( @@ -2916,7 +2873,7 @@ def _record_usage(self, config): alg = "USER_DEFINED" record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg) - @Deprecated(new="Trainer.compute_single_action()", error=False) + @Deprecated(new="Algorithm.compute_single_action()", error=False) def compute_action(self, *args, **kwargs): return self.compute_single_action(*args, **kwargs) @@ -2946,7 +2903,7 @@ def _make_workers( ) @staticmethod - @Deprecated(new="Trainer.validate_config()", error=False) + @Deprecated(new="Algorithm.validate_config()", error=False) def _validate_config(config, trainer_or_none): assert trainer_or_none is not None return trainer_or_none.validate_config(config) diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index e597ab15ee15..94a2184bc34a 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -3,7 +3,17 @@ import importlib.util import os from types import FunctionType -from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Callable, + Container, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import ray from ray.actor import ActorHandle @@ -22,19 +32,22 @@ DatasetWriter, get_dataset_and_shards, ) -from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy.policy import Policy, PolicySpec, PolicyState from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.from_config import from_config from ray.rllib.utils.typing import ( + AgentID, + AlgorithmConfigDict, EnvCreator, EnvType, + EpisodeID, + PartialAlgorithmConfigDict, PolicyID, SampleBatchType, TensorType, - AlgorithmConfigDict, ) from ray.tune.registry import registry_contains_input, registry_get_input @@ -231,6 +244,204 @@ def sync_weights( elif self.local_worker() is not None and global_vars is not None: self.local_worker().set_global_vars(global_vars) + def add_policy( + self, + policy_id: PolicyID, + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, + *, + observation_space: Optional[gym.spaces.Space] = None, + action_space: Optional[gym.spaces.Space] = None, + config: Optional[PartialAlgorithmConfigDict] = None, + policy_state: Optional[PolicyState] = None, + policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Container[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + worker_list: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, + ) -> None: + """Adds a policy to this WorkerSet's workers or a specific list of workers. + + Args: + policy_id: ID of the policy to add. + policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this WorkerSet. If not None, the + given Policy object will be directly inserted into the + local worker and clones of that Policy will be created on all remote + workers. + Note: Only one of `policy_cls` or `policy` must be provided. + observation_space: The observation space of the policy to add. + If None, try to infer this space from the environment. + action_space: The action space of the policy to add. + If None, try to infer this space from the environment. 
+ config: The config overrides for the policy to add. + policy_state: Optional state dict to apply to the new + policy instance, right after its construction. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + worker_list: A list of RolloutWorker/ActorHandles (remote + RolloutWorkers) to add this policy to. If defined, will only + add the given policy to these workers. + + Raises: + KeyError: If the given `policy_id` already exists in this WorkerSet. + """ + if ( + worker_list is None + and self.local_worker() + and policy_id in self.local_worker().policy_map + ): + raise KeyError( + f"Policy ID '{policy_id}' already exists in policy map! " + "Make sure you use a Policy ID that has not been taken yet." + " Policy IDs that are already in your policy map: " + f"{list(self.local_worker().policy_map.keys())}" + ) + + if worker_list is None: + worker_list = ( + [self.local_worker()] if self.local_worker() else [] + ) + self.remote_workers() + + self.add_policy_to_workers( + worker_list=worker_list, + policy_id=policy_id, + policy_cls=policy_cls, + policy=policy, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=policies_to_train, + ) + + @staticmethod + def add_policy_to_workers( + worker_list: List[Union[RolloutWorker, ActorHandle]], + policy_id: PolicyID, + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, + *, + observation_space: Optional[gym.spaces.Space] = None, + action_space: Optional[gym.spaces.Space] = None, + config: Optional[PartialAlgorithmConfigDict] = None, + policy_state: Optional[PolicyState] = None, + policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, + policies_to_train: Optional[ + Union[ + Container[PolicyID], + Callable[[PolicyID, Optional[SampleBatchType]], bool], + ] + ] = None, + ) -> None: + """Adds a new policy to a specific list of RolloutWorkers (or remote actors). + + Args: + worker_list: A list of RolloutWorker/ActorHandles (remote + RolloutWorkers) to add this policy to. + policy_id: ID of the policy to add. + policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this WorkerSet. If not None, the + given Policy object will be directly inserted into the + local worker and clones of that Policy will be created on all remote + workers. + Note: Only one of `policy_cls` or `policy` must be provided. + observation_space: The observation space of the policy to add. + If None, try to infer this space from the environment. + action_space: The action space of the policy to add. + If None, try to infer this space from the environment. + config: The config overrides for the policy to add. + policy_state: Optional state dict to apply to the new + policy instance, right after its construction. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. 
Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. + + Raises: + ValueError: If both `policy_cls` AND `policy` are provided. + """ + if (policy_cls is None) == (policy is None): + raise ValueError( + "Only one of `policy_cls` or `policy` must be provided to " + "Algorithm.add_policy()!" + ) + + # Policy instance not provided: Use the information given here. + if policy_cls is not None: + new_policy_instance_kwargs = dict( + policy_id=policy_id, + policy_cls=policy_cls, + observation_space=observation_space, + action_space=action_space, + config=config, + policy_state=policy_state, + policy_mapping_fn=policy_mapping_fn, + policies_to_train=list(policies_to_train) + if policies_to_train + else None, + ) + # Policy instance provided: Create clones of this very policy on the different + # workers (copy all its properties here for the calls to add_policy on the + # remote workers). + else: + new_policy_instance_kwargs = dict( + policy_id=policy_id, + policy_cls=type(policy), + observation_space=policy.observation_space, + action_space=policy.action_space, + config=policy.config, + policy_state=policy.get_state(), + policy_mapping_fn=policy_mapping_fn, + policies_to_train=list(policies_to_train) + if policies_to_train + else None, + ) + + def _create_new_policy_fn(worker: RolloutWorker): + # `foreach_worker` function: Adds the policy the the worker (and + # maybe changes its policy_mapping_fn - if provided here). + worker.add_policy(**new_policy_instance_kwargs) + + ray_gets = [] + for worker in worker_list: + # Existing policy AND local worker: Add Policy as-is. + if policy is not None and not isinstance(worker, ActorHandle): + worker.add_policy( + policy_id=policy_id, + policy=policy, + policies_to_train=policies_to_train, + policy_mapping_fn=policy_mapping_fn, + ) + # A remote worker (ray actor). + elif isinstance(worker, ActorHandle): + ray_gets.append(worker.apply.remote(_create_new_policy_fn)) + # (Local) RolloutWorker instance. + else: + worker.add_policy(**new_policy_instance_kwargs) + + ray.get(ray_gets) + def add_workers(self, num_workers: int, validate: bool = False) -> None: """Creates and adds a number of remote workers to this worker set. 
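Taken together, the commits above change `Algorithm.add_policy()` so that a caller can either have RLlib build a new Policy from a class, or hand in a readily instantiated Policy object that gets inserted into the local worker and cloned onto all remote (and evaluation) workers. The following is a rough usage sketch only, not part of the diff; the env, spaces, framework, and import paths are illustrative and mirror the updated test in rllib/algorithms/tests/test_algorithm.py:

    import gym
    from ray.rllib.algorithms import pg
    from ray.rllib.examples.env.multi_agent import MultiAgentCartPole  # assumed import path

    config = (
        pg.PGConfig()
        .environment(env=MultiAgentCartPole, env_config={"config": {"num_agents": 4}})
        .framework("torch")
        .rollouts(num_rollout_workers=2)
        .multi_agent(
            policies={"p0"},
            policy_mapping_fn=lambda aid, episode, worker, **kwargs: "p0",
        )
    )
    algo = config.build()

    # Variant 1: Let RLlib construct the new Policy from a class (pre-existing
    # behavior). Observation/action spaces are inferred from the env if omitted.
    algo.add_policy(
        policy_id="p1",
        policy_cls=algo.get_default_policy_class(config.to_dict()),
        policy_mapping_fn=lambda aid, episode, worker, **kwargs: "p1",
        policies_to_train=["p0", "p1"],
    )

    # Variant 2 (what this PR adds): Pass an already instantiated Policy object.
    # The local worker stores this very instance; remote (and evaluation)
    # workers receive clones built from its class, spaces, config, and state.
    obs_space = gym.spaces.Box(-2.0, 2.0, (4,))
    act_space = gym.spaces.Discrete(2)
    policy_obj = pg.PGTorchPolicy(obs_space, act_space, config.to_dict())
    algo.add_policy(policy_id="p2", policy=policy_obj)
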
From 447ba028a4176882781371a49a1980ba203da783 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 23 Sep 2022 10:02:35 +0200 Subject: [PATCH 6/8] wip Signed-off-by: sven1977 --- rllib/connectors/util.py | 8 ++++---- rllib/evaluation/rollout_worker.py | 14 +++++--------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/rllib/connectors/util.py b/rllib/connectors/util.py index 7ccb4ce18956..639ada8bda35 100644 --- a/rllib/connectors/util.py +++ b/rllib/connectors/util.py @@ -78,10 +78,10 @@ def create_connectors_for_policy(policy: "Policy", config: TrainerConfigDict): """ ctx: ConnectorContext = ConnectorContext.from_policy(policy) - if policy.agent_connectors is None: - policy.agent_connectors = get_agent_connectors_from_config(ctx, config) - if policy.action_connectors is None: - policy.action_connectors = get_action_connectors_from_config(ctx, config) + assert policy.agent_connectors is None and policy.agent_connectors is None + + policy.agent_connectors = get_agent_connectors_from_config(ctx, config) + policy.action_connectors = get_action_connectors_from_config(ctx, config) logger.info("Using connectors:") logger.info(policy.agent_connectors.__str__(indentation=4)) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index db83ce03aa61..19d71550346b 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -1281,15 +1281,11 @@ def add_policy( self.filters[policy_id] = get_filter(self.observation_filter, filter_shape) - if ( - self.policy_config.get("enable_connectors") - and policy_id in self.policy_map - and not ( - self.policy_map[policy_id].agent_connectors - or self.policy_map[policy_id].action_connectors - ) - ): - create_connectors_for_policy(self.policy_map[policy_id], self.policy_config) + # Create connectors for the new policy, if necessary. + # Only if connectors are enables and we created the new policy from scratch + # (it was not provided to us via the `policy` arg. + if policy is None and self.policy_config.get("enable_connectors"): + create_connectors_for_policy(new_policy, self.policy_config) self.set_policy_mapping_fn(policy_mapping_fn) if policies_to_train is not None: From 93c51a191637e4a8f2194d1246d84c6414aa0ccf Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 26 Sep 2022 19:15:25 +0200 Subject: [PATCH 7/8] wip Signed-off-by: sven1977 --- rllib/algorithms/algorithm.py | 23 ++++++----------------- rllib/evaluation/worker_set.py | 20 +++++++++++--------- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 46fc16137a4f..03ce0d7aa5d0 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -1562,9 +1562,7 @@ def add_policy( ] ] = None, evaluation_workers: bool = True, - worker_list: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, - # Deprecated args: - workers=None, + workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, ) -> Optional[Policy]: """Adds a new policy to this Algorithm. @@ -1596,32 +1594,23 @@ def add_policy( returns False) will not be updated. evaluation_workers: Whether to add the new policy also to the evaluation WorkerSet. - worker_list: A list of RolloutWorker/ActorHandles (remote + workers: A list of RolloutWorker/ActorHandles (remote RolloutWorkers) to add this policy to. If defined, will only add the given policy to these workers. Returns: The newly added policy (the copy that got added to the local - worker). 
If `worker_list` was provided, None is returned. + worker). If `workers` was provided, None is returned. Raises: ValueError: If both `policy_cls` AND `policy` are provided. KeyError: If the given `policy_id` already exists in this Algorithm. """ - # Deprecated args. - if workers is not None: - deprecation_warning( - old="Algorithm.add_policy(.., workers=...)", - new="Algorithm.add_policy(.., worker_list=...)", - error=False, - ) - worker_list = workers - # Worker list is explicitly provided -> Use only those workers (local or remote) # specified. - if worker_list is not None: - RolloutWorker.add_policy_to_workers( - worker_list, + if workers is not None: + WorkerSet.add_policy_to_workers( + workers, policy_id, policy_cls, policy, diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 94a2184bc34a..129bdfaf44ee 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -261,7 +261,7 @@ def add_policy( Callable[[PolicyID, Optional[SampleBatchType]], bool], ] ] = None, - worker_list: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, + workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, ) -> None: """Adds a policy to this WorkerSet's workers or a specific list of workers. @@ -291,7 +291,7 @@ def add_policy( If None, will keep the existing setup in place. Policies, whose IDs are not in the list (or for which the callable returns False) will not be updated. - worker_list: A list of RolloutWorker/ActorHandles (remote + workers: A list of RolloutWorker/ActorHandles (remote RolloutWorkers) to add this policy to. If defined, will only add the given policy to these workers. @@ -299,7 +299,7 @@ def add_policy( KeyError: If the given `policy_id` already exists in this WorkerSet. """ if ( - worker_list is None + workers is None and self.local_worker() and policy_id in self.local_worker().policy_map ): @@ -310,13 +310,15 @@ def add_policy( f"{list(self.local_worker().policy_map.keys())}" ) - if worker_list is None: - worker_list = ( + # No `workers` arg provided: Compile list of workers automatically from + # all RolloutWorkers in this WorkerSet. + if workers is None: + workers = ( [self.local_worker()] if self.local_worker() else [] ) + self.remote_workers() self.add_policy_to_workers( - worker_list=worker_list, + workers=workers, policy_id=policy_id, policy_cls=policy_cls, policy=policy, @@ -330,7 +332,7 @@ def add_policy( @staticmethod def add_policy_to_workers( - worker_list: List[Union[RolloutWorker, ActorHandle]], + workers: List[Union[RolloutWorker, ActorHandle]], policy_id: PolicyID, policy_cls: Optional[Type[Policy]] = None, policy: Optional[Policy] = None, @@ -350,7 +352,7 @@ def add_policy_to_workers( """Adds a new policy to a specific list of RolloutWorkers (or remote actors). Args: - worker_list: A list of RolloutWorker/ActorHandles (remote + workers: A list of RolloutWorker/ActorHandles (remote RolloutWorkers) to add this policy to. policy_id: ID of the policy to add. policy_cls: The Policy class to use for constructing the new Policy. @@ -424,7 +426,7 @@ def _create_new_policy_fn(worker: RolloutWorker): worker.add_policy(**new_policy_instance_kwargs) ray_gets = [] - for worker in worker_list: + for worker in workers: # Existing policy AND local worker: Add Policy as-is. 
if policy is not None and not isinstance(worker, ActorHandle): worker.add_policy( From 4d42ac2f34a6e6db7b87b67d07f76ab509ac907c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 26 Sep 2022 19:17:01 +0200 Subject: [PATCH 8/8] wip Signed-off-by: sven1977 --- rllib/algorithms/algorithm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 03ce0d7aa5d0..1e358b388799 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -1609,6 +1609,7 @@ def add_policy( # Worker list is explicitly provided -> Use only those workers (local or remote) # specified. if workers is not None: + # Call static utility method. WorkerSet.add_policy_to_workers( workers, policy_id, @@ -1636,7 +1637,7 @@ def add_policy( ) # Add to evaluation workers, if necessary. - if evaluation_workers and self.evaluation_workers is not None: + if evaluation_workers is True and self.evaluation_workers is not None: self.evaluation_workers.add_policy( policy_id, policy_cls,
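The last two commits settle the public signature on `workers=...` (dropping the interim `worker_list=...` name) and route explicit worker lists through the static `WorkerSet.add_policy_to_workers()` helper. A hedged sketch of that path, reusing `algo` and `config` from the earlier example (policy IDs and the choice of target workers are illustrative):

    from ray.rllib.evaluation.worker_set import WorkerSet

    # Add a policy to an explicit subset of workers only, e.g. the remote
    # rollout workers. Per the updated docstring, `Algorithm.add_policy()`
    # returns None in this case, because the local worker holds no copy.
    target_workers = algo.workers.remote_workers()
    algo.add_policy(
        policy_id="p3",
        policy_cls=algo.get_default_policy_class(config.to_dict()),
        workers=target_workers,
    )

    # Roughly the equivalent low-level call: the static utility adds the policy
    # to each given worker, calling `add_policy()` directly on local
    # RolloutWorker instances and `apply.remote(...)` on remote actor handles.
    WorkerSet.add_policy_to_workers(
        target_workers,
        policy_id="p4",
        policy_cls=algo.get_default_policy_class(config.to_dict()),
    )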