diff --git a/doc/source/rllib/rllib-env.rst b/doc/source/rllib/rllib-env.rst
index 1f1ec9beb57b..06411383b4c0 100644
--- a/doc/source/rllib/rllib-env.rst
+++ b/doc/source/rllib/rllib-env.rst
@@ -119,38 +119,105 @@ When using remote envs, you can control the batching level for inference with ``
 Multi-Agent and Hierarchical
 ----------------------------
 
-A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car"- and "traffic light" agents in the environment. The model for multi-agent in RLlib is as follows: (1) as a user, you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure:
+In a multi-agent environment, more than one "agent" acts either simultaneously, in a turn-based fashion, or in a combination of these two.
+
+For example, in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment,
+all acting simultaneously. In a board game, on the other hand, you may have two or more agents acting in a turn-based fashion.
+
+The mental model for multi-agent in RLlib is as follows:
+(1) Your environment (a sub-class of :py:class:`~ray.rllib.env.multi_agent_env.MultiAgentEnv`) returns dictionaries mapping agent IDs (e.g. strings; the env can choose these arbitrarily) to the individual agents' observations, rewards, and done-flags.
+(2) You define (some of) the policies that are available up front (you can also add new policies on-the-fly throughout training), and
+(3) You define a function that maps an env-produced agent ID to one of the available policy IDs, which is then used to compute actions for this particular agent.
+
+This is summarized by the figure below:
 
 .. image:: images/multi-agent.svg
 
-The environment itself must subclass the `MultiAgentEnv `__ interface, which can return observations and rewards from multiple ready agents per step:
+When implementing your own :py:class:`~ray.rllib.env.multi_agent_env.MultiAgentEnv`, note that the returned observation dict
+should only contain those agent IDs for which you expect to receive actions in the next call to `step()`.
+
+This API allows you to implement any type of multi-agent environment, from `turn-based games `__
+to environments in which `all agents always act simultaneously `__, and anything in between.
+
+Here is an example of an env in which all agents always step simultaneously:
 
 .. code-block:: python
 
-    # Example: using a multi-agent env
-    > env = MultiAgentTrafficEnv(num_cars=20, num_traffic_lights=5)
+    # Env, in which all agents (whose IDs are entirely determined by the env
+    # itself via the returned multi-agent obs/reward/dones dicts) step
+    # simultaneously.
+    env = MultiAgentTrafficEnv(num_cars=2, num_traffic_lights=1)
 
     # Observations are a dict mapping agent names to their obs. Only those
-    # agents' names that require actions in the next call to `step()` will
-    # be present in the returned observation dict.
-    > print(env.reset())
-    {
-        "car_1": [[...]],
-        "car_2": [[...]],
-        "traffic_light_1": [[...]],
-    }
+    # agents' names that require actions in the next call to `step()` should
+    # be present in the returned observation dict (here: all, as we always step
+    # simultaneously).
+    print(env.reset())
+    # ... {
+    # ...     "car_1": [[...]],
+    # ...     "car_2": [[...]],
+    # ...     "traffic_light_1": [[...]],
+    # ... }
 
     # In the following call to `step`, actions should be provided for each
     # agent that returned an observation before:
-    > new_obs, rewards, dones, infos = env.step(actions={"car_1": ..., "car_2": ..., "traffic_light_1": ...})
+    new_obs, rewards, dones, infos = env.step(
+        actions={"car_1": ..., "car_2": ..., "traffic_light_1": ...})
 
-    # Similarly, new_obs, rewards, dones, etc. also become dicts
-    > print(rewards)
-    {"car_1": 3, "car_2": -1, "traffic_light_1": 0}
+    # Similarly, new_obs, rewards, dones, etc. also become dicts.
+    print(rewards)
+    # ... {"car_1": 3, "car_2": -1, "traffic_light_1": 0}
+
+    # Individual agents can exit the episode early; the entire episode is done
+    # when `dones["__all__"] = True`.
+    print(dones)
+    # ... {"car_2": True, "__all__": False}
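+
+For reference, here is a minimal sketch of what such a simultaneously-stepping env
+could look like. The class name, spaces, and reward logic below are made up for
+illustration; only the dict-based `reset()`/`step()` contract is the part prescribed by RLlib:
+
+.. code-block:: python
+
+    import gym
+    from ray.rllib.env.multi_agent_env import MultiAgentEnv
+
+    class SimpleSimultaneousEnv(MultiAgentEnv):
+        """Two agents ("agent_0" and "agent_1") always act at the same time."""
+
+        def __init__(self, config=None):
+            super().__init__()
+            self.observation_space = gym.spaces.Discrete(2)
+            self.action_space = gym.spaces.Discrete(2)
+            self.num_steps = 0
+
+        def reset(self):
+            self.num_steps = 0
+            # Return an obs for ALL agents (they will all act in the next `step()`).
+            return {"agent_0": 0, "agent_1": 0}
+
+        def step(self, action_dict):
+            # `action_dict` contains one action for each agent that received
+            # an obs in the previous call.
+            self.num_steps += 1
+            # Made-up reward: +1.0 for both agents if they chose the same action.
+            same = action_dict["agent_0"] == action_dict["agent_1"]
+            obs = {"agent_0": int(same), "agent_1": int(same)}
+            rewards = {"agent_0": float(same), "agent_1": float(same)}
+            done = self.num_steps >= 10
+            dones = {"agent_0": done, "agent_1": done, "__all__": done}
+            return obs, rewards, dones, {}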
+
+Another example, in which agents step one after the other (turn-based game):
+
+.. code-block:: python
+
+    # Env, in which two agents step in sequence (turn-based game).
+    # The env is in charge of which agent IDs it produces. Our env here produces
+    # the agent IDs "player1" and "player2".
+    env = TicTacToe()
+
+    # Observations are a dict mapping agent names to their obs. Only those
+    # agents' names that require actions in the next call to `step()` should
+    # be present in the returned observation dict (here: one agent at a time).
+    print(env.reset())
+    # ... {
+    # ...     "player1": [[...]],
+    # ... }
+
+    # In the following call to `step`, only those agents' actions should be
+    # provided that were present in the returned obs dict:
+    new_obs, rewards, dones, infos = env.step(actions={"player1": ...})
+
+    # Similarly, new_obs, rewards, dones, etc. also become dicts.
+    # Note that the `rewards` dict (unlike the obs dict) may contain any agent,
+    # even those that have not(!) acted in this `step()` call. Rewards for an
+    # individual agent are accumulated up to the point where that agent needs to
+    # act again. This way, you can implement a turn-based 2-player game, in which
+    # player2's reward is published in the `rewards` dict immediately after
+    # player1 has acted.
+    print(rewards)
+    # ... {"player1": 0, "player2": 0}
+
+    # Individual agents can exit the episode early; the entire episode is done
+    # when `dones["__all__"] = True`.
+    print(dones)
+    # ... {"player1": False, "__all__": False}
+
+    # In the next step, it is player2's turn. Therefore, `new_obs` only contains
+    # this agent's ID:
+    print(new_obs)
+    # ... {
+    # ...     "player2": [[...]]
+    # ... }
 
-    # Individual agents can early exit; The entire episode is done when "__all__" = True
-    > print(dones)
-    {"car_2": True, "__all__": False}
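+
+The reward-accumulation behavior described in the comments above could be implemented
+in a turn-based env's `step()` method roughly as follows. This is only a sketch; the
+`self._...` helper methods are hypothetical and not part of RLlib:
+
+.. code-block:: python
+
+    def step(self, action_dict):
+        # Exactly one player acted (the one that received an obs last time).
+        [acting_player] = list(action_dict.keys())
+        waiting_player = "player2" if acting_player == "player1" else "player1"
+
+        self._apply_move(acting_player, action_dict[acting_player])
+
+        # Next obs: only for the player whose turn it is now.
+        obs = {waiting_player: self._board_observation()}
+
+        # Rewards may list BOTH players, even though only one of them acted.
+        # RLlib accumulates these per agent until that agent acts again.
+        rewards = {acting_player: 0.0, waiting_player: 0.0}
+        winner = self._winner()  # None, "player1", or "player2"
+        if winner is not None:
+            rewards[winner] = 1.0
+            rewards["player1" if winner == "player2" else "player2"] = -1.0
+
+        dones = {"__all__": self._game_over()}
+        return obs, rewards, dones, {}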
 
 If all the agents will be using the same algorithm class to train, then you can setup multi-agent training as follows:
 
@@ -159,13 +226,28 @@ If all the agents will be using the same algorithm class to train, then you can
     trainer = pg.PGAgent(env="my_multiagent_env", config={
        "multiagent": {
            "policies": {
-                # the first tuple value is None -> uses default policy
-                "car1": (None, car_obs_space, car_act_space, {"gamma": 0.85}),
+                # Use the PolicySpec namedtuple to specify an individual policy:
+                "car1": PolicySpec(
+                    policy_class=None,  # infer automatically from Trainer
+                    observation_space=None,  # infer automatically from env
+                    action_space=None,  # infer automatically from env
+                    config={"gamma": 0.85},  # use main config plus <- this override here
+                ),  # alternatively, simply do: `PolicySpec(config={"gamma": 0.85})`
+
+                # Deprecated way: Tuple specifying class, obs-/action-spaces,
+                # and config-overrides for each policy.
+                # If class is None -> Uses Trainer's default policy class.
                "car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}),
-                "traffic_light": (None, tl_obs_space, tl_act_space, {}),
+
+                # New way: Use PolicySpec() with keywords: `policy_class`,
+                # `observation_space`, `action_space`, `config`.
+                "traffic_light": PolicySpec(
+                    observation_space=tl_obs_space,  # special obs space for lights?
+                    action_space=tl_act_space,  # special action space for lights?
+                ),
            },
            "policy_mapping_fn":
-                lambda agent_id:
+                lambda agent_id, episode, worker, **kwargs:
                    "traffic_light"  # Traffic lights are always controlled by this policy
                    if agent_id.startswith("traffic_light_")
                    else random.choice(["car1", "car2"])  # Randomly choose from car policies
@@ -175,9 +257,56 @@ If all the agents will be using the same algorithm class to train, then you can
     while True:
         print(trainer.train())
 
-RLlib will create three distinct policies and route agent decisions to its bound policy. When an agent first appears in the env, ``policy_mapping_fn`` will be called to determine which policy it is bound to. RLlib reports separate training statistics for each policy in the return from ``train()``, along with the combined reward.
+To exclude some of the policies in your ``multiagent.policies`` dictionary from being trained, use the ``multiagent.policies_to_train`` setting.
+For example, you may want to have one or more random (non-learning) policies interact with your learning ones:
+
+.. code-block:: python
+
+    # Example of a mapping function that maps agent IDs "player1" and "player2" to either
+    # "random_policy" or "learning_policy", making sure that in each episode, both policies
+    # are always playing each other.
+    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
+        agent_idx = int(agent_id[-1]) - 1  # 0 (player1) or 1 (player2)
+        # agent_id = "player[1|2]" -> policy depends on episode ID
+        # This way, we make sure that both policies sometimes play player1
+        # (start player) and sometimes player2 (player to move 2nd).
+        return "learning_policy" if episode.episode_id % 2 == agent_idx else "random_policy"
+
+    trainer = pg.PGAgent(env="two_player_game", config={
+        "multiagent": {
+            "policies": {
+                "learning_policy": PolicySpec(),  # <- use default class & infer obs-/act-spaces from env.
+                "random_policy": PolicySpec(policy_class=RandomPolicy),  # infer obs-/act-spaces from env.
+            },
+            # Use the mapping function defined above.
+            "policy_mapping_fn": policy_mapping_fn,
+            # Specify a (fixed) list (or set) of policy IDs that should be updated.
+            "policies_to_train": ["learning_policy"],
+
+            # Alternatively, you can provide a callable that takes a policy ID and
+            # an (optional) SampleBatch and returns True or False:
+
+            # "policies_to_train": lambda pid, batch: ... (<- return True or False)
+
+            # This allows you to more flexibly update (or not update) policies, based on
+            # who they played with in the episode (or any other information that can be
+            # found in the given batch, e.g. rewards).
+        },
+    })
+
+
+In the first example above, RLlib will create three distinct policies and route each agent's decisions to its bound policy, using the given ``policy_mapping_fn``.
+When an agent first appears in the env, ``policy_mapping_fn`` will be called to determine which policy it is bound to.
+RLlib reports separate training statistics for each policy in the return from ``train()``, along with the combined reward.
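+
+As an illustration of the callable form of ``policies_to_train`` shown (commented out) in the
+example above, the following sketch only updates the learning policy on batches that actually
+contain its own experiences. It assumes the passed-in batch is a multi-agent batch exposing a
+``policy_batches`` dict; the function name itself is arbitrary:
+
+.. code-block:: python
+
+    def policies_to_train_fn(policy_id, batch=None):
+        # Never update the (non-learning) random policy.
+        if policy_id == "random_policy":
+            return False
+        # No batch information available -> default to "trainable".
+        if batch is None:
+            return True
+        # Only train on batches that contain this policy's own samples.
+        return policy_id in batch.policy_batches
+
+    # Then, in your config:
+    # "multiagent": {
+    #     ...,
+    #     "policies_to_train": policies_to_train_fn,
+    # },
+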
-Here is a simple `example training script `__ in which you can vary the number of agents and policies in the environment. For how to use multiple training methods at once (here DQN and PPO), see the `two-trainer example `__. Metrics are reported for each policy separately, for example: +Here is a simple `example training script `__ +in which you can vary the number of agents and policies in the environment. +For how to use multiple training methods at once (here DQN and PPO), +see the `two-trainer example `__. +Metrics are reported for each policy separately, for example: .. code-block:: bash :emphasize-lines: 6,14,22 @@ -207,7 +336,8 @@ Here is a simple `example training script 1``. +To scale to hundreds of agents (if these agents are using the same policy), MultiAgentEnv batches policy evaluations across multiple agents internally. +Your ``MultiAgentEnvs`` are also auto-vectorized (as can be normal, single-agent envs, e.g. gym.Env) by setting ``num_envs_per_worker > 1``. PettingZoo Multi-Agent Environments @@ -330,7 +460,9 @@ This can be implemented as a multi-agent environment with three types of agents. }, -In this setup, the appropriate rewards for training lower-level agents must be provided by the multi-agent env implementation. The environment class is also responsible for routing between the agents, e.g., conveying `goals `__ from higher-level agents to lower-level agents as part of the lower-level agent observation. +In this setup, the appropriate rewards for training lower-level agents must be provided by the multi-agent env implementation. +The environment class is also responsible for routing between the agents, e.g., conveying `goals `__ from higher-level +agents to lower-level agents as part of the lower-level agent observation. See this file for a runnable example: `hierarchical_training.py `__. diff --git a/rllib/agents/dqn/apex.py b/rllib/agents/dqn/apex.py index 883230f4a973..b6e804bb1ba6 100644 --- a/rllib/agents/dqn/apex.py +++ b/rllib/agents/dqn/apex.py @@ -241,7 +241,7 @@ def update_prio_and_stats(item: Tuple[ActorHandle, dict, int]) -> None: def add_apex_metrics(result: dict) -> dict: replay_stats = ray.get(replay_actors[0].stats.remote( config["optimizer"].get("debug"))) - exploration_infos = workers.foreach_trainable_policy( + exploration_infos = workers.foreach_policy_to_train( lambda p, _: p.get_exploration_state()) result["info"].update({ "exploration_infos": exploration_infos, diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py index 30fbcb95d20d..f3a27cb8909f 100644 --- a/rllib/agents/maml/maml.py +++ b/rllib/agents/maml/maml.py @@ -127,7 +127,7 @@ def update(pi, pi_id): else: logger.warning("No data for {}, not updating kl".format(pi_id)) - self.workers.local_worker().foreach_trainable_policy(update) + self.workers.local_worker().foreach_policy_to_train(update) # Modify Reporting Metrics metrics = _get_shared_metrics() diff --git a/rllib/agents/mbmpo/mbmpo.py b/rllib/agents/mbmpo/mbmpo.py index 655b0b2d552f..dfc641ab750f 100644 --- a/rllib/agents/mbmpo/mbmpo.py +++ b/rllib/agents/mbmpo/mbmpo.py @@ -169,7 +169,7 @@ def update(pi, pi_id): else: logger.warning("No data for {}, not updating kl".format(pi_id)) - self.workers.local_worker().foreach_trainable_policy(update) + self.workers.local_worker().foreach_policy_to_train(update) # Modify Reporting Metrics. 
metrics = _get_shared_metrics() diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 2cfab4ff54f4..f286c3b0180c 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -100,7 +100,7 @@ def __call__(self, fetches): metrics.counters[NUM_TARGET_UPDATES] += 1 metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts # Update Target Network - self.workers.local_worker().foreach_trainable_policy( + self.workers.local_worker().foreach_policy_to_train( lambda p, _: p.update_target()) # Also update KL Coeff if self.config["use_kl_loss"]: @@ -117,7 +117,7 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) # After init: Initialize target net. - self.workers.local_worker().foreach_trainable_policy( + self.workers.local_worker().foreach_policy_to_train( lambda p, _: p.update_target()) @classmethod diff --git a/rllib/agents/ppo/ppo.py b/rllib/agents/ppo/ppo.py index 42d4b7b8400d..a695b9ed6bca 100644 --- a/rllib/agents/ppo/ppo.py +++ b/rllib/agents/ppo/ppo.py @@ -128,7 +128,7 @@ def update(pi, pi_id): # Update KL on all trainable policies within the local (trainer) # Worker. - self.workers.local_worker().foreach_trainable_policy(update) + self.workers.local_worker().foreach_policy_to_train(update) def warn_about_bad_reward_scales(config, result): @@ -271,7 +271,7 @@ def execution_plan(workers: WorkerSet, config: TrainerConfigDict, # Collect batches for the trainable policies. rollouts = rollouts.for_each( - SelectExperiences(workers.trainable_policies())) + SelectExperiences(local_worker=workers.local_worker())) # Concatenate the SampleBatches into one. rollouts = rollouts.combine( ConcatBatches( diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 86ed985923ea..d217a951f3c4 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -55,7 +55,7 @@ from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import AgentID, EnvCreator, EnvInfoDict, EnvType, \ EpisodeID, PartialTrainerConfigDict, PolicyID, PolicyState, ResultDict, \ - TensorStructType, TensorType, TrainerConfigDict + SampleBatchType, TensorStructType, TensorType, TrainerConfigDict from ray.tune.logger import Logger, UnifiedLogger from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.resources import Resources @@ -522,7 +522,15 @@ "policy_map_cache": None, # Function mapping agent ids to policy ids. "policy_mapping_fn": None, - # Optional list of policies to train, or None for all policies. + # Determines those policies that should be updated. + # Options are: + # - None, for all policies. + # - An iterable of PolicyIDs that should be updated. + # - A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch + # and returning a bool (indicating whether the given policy is trainable + # or not, given the particular batch). This allows you to have a policy + # trained only on certain data (e.g. when playing against a certain + # opponent). "policies_to_train": None, # Optional function that can be used to enhance the local agent # observations to include more state. 
@@ -1693,7 +1701,8 @@ def add_policy( policy_state: Optional[PolicyState] = None, policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, - policies_to_train: Optional[Container[PolicyID]] = None, + policies_to_train: Optional[Union[Container[PolicyID], Callable[ + [PolicyID, Optional[SampleBatchType]], bool]]] = None, evaluation_workers: bool = True, workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = None, ) -> Policy: @@ -1714,9 +1723,12 @@ def add_policy( to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. - policies_to_train: An optional list/set of policy IDs to be - trained. If None, will keep the existing list in place. - Policies, whose IDs are not in the list will not be updated. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. evaluation_workers: Whether to add the new policy also to the evaluation WorkerSet. workers: A list of RolloutWorker/ActorHandles (remote @@ -1769,7 +1781,8 @@ def remove_policy( policy_id: PolicyID = DEFAULT_POLICY_ID, *, policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None, - policies_to_train: Optional[List[PolicyID]] = None, + policies_to_train: Optional[Union[Set[PolicyID], Callable[ + [PolicyID, Optional[SampleBatchType]], bool]]] = None, evaluation_workers: bool = True, ) -> None: """Removes a new policy from this Trainer. @@ -1780,9 +1793,12 @@ def remove_policy( to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. - policies_to_train: An optional list of policy IDs to be trained. - If None, will keep the existing list in place. Policies, - whose IDs are not in the list will not be updated. + policies_to_train: An optional list of policy IDs to be trained + or a callable taking PolicyID and SampleBatchType and + returning a bool (trainable or not?). + If None, will keep the existing setup in place. Policies, + whose IDs are not in the list (or for which the callable + returns False) will not be updated. evaluation_workers: Whether to also remove the policy from the evaluation WorkerSet. 
""" diff --git a/rllib/contrib/maddpg/maddpg.py b/rllib/contrib/maddpg/maddpg.py index 414e68940b88..5ddfc444f32c 100644 --- a/rllib/contrib/maddpg/maddpg.py +++ b/rllib/contrib/maddpg/maddpg.py @@ -172,7 +172,7 @@ def validate_config(self, config: TrainerConfigDict) -> None: def f(batch, workers, config): policies = dict(workers.local_worker() - .foreach_trainable_policy(lambda p, i: (i, p))) + .foreach_policy_to_train(lambda p, i: (i, p))) return before_learn_on_batch(batch, policies, config["train_batch_size"]) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index a9688a5b54b5..cd913dfdf8bb 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -6,8 +6,8 @@ import platform import os import tree # pip install dm_tree -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \ - TYPE_CHECKING, Union +from typing import Any, Callable, Container, Dict, List, Optional, Set, \ + Tuple, Type, TYPE_CHECKING, Union import ray from ray import ObjectRef @@ -34,7 +34,8 @@ from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.utils import force_list, merge_dicts -from ray.rllib.utils.annotations import Deprecated, DeveloperAPI +from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, \ + ExperimentalAPI from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.error import EnvError, ERR_MSG_NO_GPUS, \ @@ -188,7 +189,8 @@ def __init__( PolicySpec]]] = None, policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None, - policies_to_train: Optional[List[PolicyID]] = None, + policies_to_train: Union[Container[PolicyID], Callable[ + [PolicyID, SampleBatchType], bool]] = None, tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None, rollout_fragment_length: int = 100, count_steps_by: str = "env_steps", @@ -246,8 +248,9 @@ def __init__( agent appears in an episode, to bind that agent to a policy for the duration of the episode. If not provided, will map all agents to DEFAULT_POLICY_ID. - policies_to_train: Optional list of policies to train, or None - for all policies. + policies_to_train: Optional container of policies to train (None + for all policies), or a callable taking PolicyID and + SampleBatchType and returning a bool (trainable or not?). tf_session_creator: A function that returns a TF session. This is optional and only useful with TFPolicy. rollout_fragment_length: The target number of steps @@ -544,11 +547,17 @@ def make_sub_env(vector_index): spaces=self.spaces, policy_config=policy_config) - # List of IDs of those policies, which should be trained. - # By default, these are all policies found in `self.policy_dict`. - self.policies_to_train: List[PolicyID] = policies_to_train or list( - self.policy_dict.keys()) - self.set_policies_to_train(self.policies_to_train) + # Set of IDs of those policies, which should be trained. This property + # is optional and mainly used for backward compatibility. + self.policies_to_train = policies_to_train + self.is_policy_to_train: Callable[[PolicyID, SampleBatchType], bool] + + # By default (None), use the set of all policies found in the + # policy_dict. 
+ if self.policies_to_train is None: + self.policies_to_train = set(self.policy_dict.keys()) + + self.set_is_policy_to_train(self.policies_to_train) self.policy_map: PolicyMap = None self.preprocessors: Dict[PolicyID, Preprocessor] = None @@ -843,7 +852,7 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict: builders = {} to_fetch = {} for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: + if not self.is_policy_to_train(pid, samples): continue # Decompress SampleBatch, in case some columns are compressed. batch.decompress_if_needed() @@ -859,10 +868,11 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict: {pid: builders[pid].get(v) for pid, v in to_fetch.items()}) else: - info_out = { - DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID] - .learn_on_batch(samples) - } + if self.is_policy_to_train(DEFAULT_POLICY_ID, samples): + info_out = { + DEFAULT_POLICY_ID: self.policy_map[DEFAULT_POLICY_ID] + .learn_on_batch(samples) + } if log_once("learn_out"): logger.debug("Training out:\n\n{}\n".format(summarize(info_out))) return info_out @@ -902,11 +912,12 @@ def compute_gradients( """Returns a gradient computed w.r.t the specified samples. Uses the Policy's/ies' compute_gradients method(s) to perform the - calculations. + calculations. Skips policies that are not trainable as per + `self.is_policy_to_train()`. Args: samples: The SampleBatch or MultiAgentBatch to compute gradients - for using this worker's policies. + for using this worker's trainable policies. Returns: In the single-agent case, a tuple consisting of ModelGradients and @@ -929,7 +940,7 @@ def compute_gradients( grad_out, info_out = {}, {} if self.policy_config.get("framework") == "tf": for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: + if not self.is_policy_to_train(pid, samples): continue policy = self.policy_map[pid] builder = TFRunBuilder(policy.get_session(), @@ -940,7 +951,7 @@ def compute_gradients( info_out = {k: builder.get(v) for k, v in info_out.items()} else: for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: + if not self.is_policy_to_train(pid, samples): continue grad_out[pid], info_out[pid] = ( self.policy_map[pid].compute_gradients(batch)) @@ -982,10 +993,10 @@ def apply_gradients( # Multi-agent case. if isinstance(grads, dict): for pid, g in grads.items(): - if pid in self.policies_to_train: + if self.is_policy_to_train(pid, None): self.policy_map[pid].apply_gradients(g) # Grads is a ModelGradients type. Single-agent case. - elif DEFAULT_POLICY_ID in self.policies_to_train: + elif self.is_policy_to_train(DEFAULT_POLICY_ID, None): self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) @DeveloperAPI @@ -1087,7 +1098,8 @@ def add_policy( policy_state: Optional[PolicyState] = None, policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None, - policies_to_train: Optional[List[PolicyID]] = None, + policies_to_train: Optional[Union[Container[PolicyID], Callable[ + [PolicyID, SampleBatchType], bool]]] = None, ) -> Policy: """Adds a new policy to this RolloutWorker. @@ -1104,9 +1116,12 @@ def add_policy( to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. - policies_to_train: An optional list of policy IDs to be trained. - If None, will keep the existing list in place. Policies, - whose IDs are not in the list will not be updated. 
+ policies_to_train: An optional container of policy IDs to be + trained or a callable taking PolicyID and - optionally - + SampleBatchType and returning a bool (trainable or not?). + If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the + callable returns False) will not be updated. Returns: The newly added policy. @@ -1140,7 +1155,8 @@ def add_policy( self.observation_filter, new_policy.observation_space.shape) self.set_policy_mapping_fn(policy_mapping_fn) - self.set_policies_to_train(policies_to_train) + if policies_to_train is not None: + self.set_is_policy_to_train(policies_to_train) return new_policy @@ -1150,7 +1166,8 @@ def remove_policy( *, policy_id: PolicyID = DEFAULT_POLICY_ID, policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None, - policies_to_train: Optional[List[PolicyID]] = None, + policies_to_train: Optional[Union[Container[PolicyID], Callable[ + [PolicyID, SampleBatchType], bool]]] = None, ) -> None: """Removes a policy from this RolloutWorker. @@ -1161,16 +1178,20 @@ def remove_policy( to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode. - policies_to_train: An optional list of policy IDs to be trained. - If None, will keep the existing list in place. Policies, - whose IDs are not in the list will not be updated. + policies_to_train: An optional container of policy IDs to be + trained or a callable taking PolicyID and - optionally - + SampleBatchType and returning a bool (trainable or not?). + If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the + callable returns False) will not be updated. """ if policy_id not in self.policy_map: raise ValueError(f"Policy ID '{policy_id}' not in policy map!") del self.policy_map[policy_id] del self.preprocessors[policy_id] self.set_policy_mapping_fn(policy_mapping_fn) - self.set_policies_to_train(policies_to_train) + if policies_to_train is not None: + self.set_is_policy_to_train(policies_to_train) @DeveloperAPI def set_policy_mapping_fn( @@ -1190,16 +1211,54 @@ def set_policy_mapping_fn( raise ValueError("`policy_mapping_fn` must be a callable!") @DeveloperAPI - def set_policies_to_train( - self, policies_to_train: Optional[List[PolicyID]] = None) -> None: - """Sets `self.policies_to_train` to a new list of PolicyIDs. + def set_is_policy_to_train( + self, is_policy_to_train: Union[Container[PolicyID], Callable[ + [PolicyID, Optional[SampleBatchType]], bool]]) -> None: + """Sets `self.is_policy_to_train()` to a new callable. Args: - policies_to_train: The new list of policy IDs to train with. - If None, will keep the existing list in place. + is_policy_to_train: A container of policy IDs to be + trained or a callable taking PolicyID and - optionally - + SampleBatchType and returning a bool (trainable or not?). + If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the + callable returns False) will not be updated. """ - if policies_to_train is not None: - self.policies_to_train = policies_to_train + # If container given, construct a simple default callable returning True + # if the PolicyID is found in the list/set of IDs. + if not callable(is_policy_to_train): + assert isinstance(is_policy_to_train, Container), \ + "ERROR: `is_policy_to_train`must be a container or a " \ + "callable taking PolicyID and SampleBatch and returning " \ + "True|False (trainable or not?)." 
+ pols = set(is_policy_to_train) + + def is_policy_to_train(pid, batch=None): + return pid in pols + + self.is_policy_to_train = is_policy_to_train + + @ExperimentalAPI + def get_policies_to_train( + self, batch: Optional[SampleBatchType] = None) -> Set[PolicyID]: + """Returns all policies-to-train, given an optional batch. + + Loops through all policies currently in `self.policy_map` and checks + the return value of `self.is_policy_to_train(pid, batch)`. + + Args: + batch: An optional SampleBatchType for the + `self.is_policy_to_train(pid, [batch]?)` check. + + Returns: + The set of currently trainable policy IDs, given the optional + `batch`. + """ + return { + pid + for pid in self.policy_map.keys() + if self.is_policy_to_train(pid, batch) + } @DeveloperAPI def for_policy(self, @@ -1243,19 +1302,19 @@ def foreach_policy(self, ] @DeveloperAPI - def foreach_trainable_policy( + def foreach_policy_to_train( self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs) -> List[T]: """ Calls the given function with each (policy, policy_id) tuple. - - Only those policies/IDs will be called on, which can be found in - `self.policies_to_train`. + Only those policies/IDs will be called on, for which + `self.is_policy_to_train()` returns True. Args: func: The function to call with each (policy, policy ID) tuple, - for only those policies that are in `self.policies_to_train`. + for only those policies that `self.is_policy_to_train` + returns True. Keyword Args: kwargs: Additional kwargs to be passed to the call. @@ -1267,7 +1326,7 @@ def foreach_trainable_policy( return [ func(policy, pid, **kwargs) for pid, policy in self.policy_map.items() - if pid in self.policies_to_train + if self.is_policy_to_train(pid, None) ] @DeveloperAPI @@ -1355,7 +1414,7 @@ def restore(self, objs: bytes) -> None: @DeveloperAPI def get_weights( self, - policies: Optional[List[PolicyID]] = None, + policies: Optional[Container[PolicyID]] = None, ) -> Dict[PolicyID, ModelWeights]: """Returns each policies' model weights of this worker. 
@@ -1609,6 +1668,10 @@ def export_policy_checkpoint(self, self.policy_map[policy_id].export_checkpoint(export_dir, filename_prefix) + @Deprecated(new="RolloutWorker.foreach_policy_to_train", error=False) + def foreach_trainable_policy(self, func, **kwargs): + return self.foreach_policy_to_train(func, **kwargs) + def _determine_spaces_for_multi_agent_dict( multi_agent_dict: MultiAgentPolicyConfigDict, diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 919f19dd63c6..bba14ebf290b 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -2,7 +2,8 @@ import logging import importlib.util from types import FunctionType -from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, \ + Union import ray from ray import data @@ -15,10 +16,11 @@ from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.utils import merge_dicts from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.from_config import from_config from ray.rllib.utils.typing import EnvCreator, EnvType, PolicyID, \ - TrainerConfigDict + SampleBatchType, TrainerConfigDict from ray.tune.registry import registry_contains_input, registry_get_input tf1, tf, tfv = try_import_tf() @@ -231,6 +233,18 @@ def stop(self) -> None: for w in self.remote_workers(): w.__ray_terminate__.remote() + @DeveloperAPI + def is_policy_to_train(self, + policy_id: PolicyID, + batch: Optional[SampleBatchType] = None) -> bool: + """Whether given PolicyID (optionally inside some batch) is trainable. + """ + local_worker = self.local_worker() + if local_worker: + return local_worker.is_policy_to_train(policy_id, batch) + else: + raise NotImplementedError + @DeveloperAPI def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]: """Calls the given function with each worker instance as arg. @@ -308,21 +322,14 @@ def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]: return results @DeveloperAPI - def trainable_policies(self) -> List[PolicyID]: - """Returns the list of trainable policy ids.""" - if self.local_worker() is not None: - return self.local_worker().policies_to_train - else: - raise NotImplementedError - - @DeveloperAPI - def foreach_trainable_policy( + def foreach_policy_to_train( self, func: Callable[[Policy, PolicyID], T]) -> List[T]: """Apply `func` to all workers' Policies iff in `policies_to_train`. Args: func: A function - taking a Policy and its ID - that is - called on all workers' Policies in `worker.policies_to_train`. + called on all workers' Policies, for which + `worker.is_policy_to_train()` returns True. 
Returns: List[any]: The list of n return values of all @@ -330,12 +337,11 @@ def foreach_trainable_policy( """ results = [] if self.local_worker() is not None: - results = self.local_worker().foreach_trainable_policy(func) + results = self.local_worker().foreach_policy_to_train(func) ray_gets = [] for worker in self.remote_workers(): ray_gets.append( - worker.apply.remote( - lambda w: w.foreach_trainable_policy(func))) + worker.apply.remote(lambda w: w.foreach_policy_to_train(func))) remote_results = ray.get(ray_gets) for r in remote_results: results.extend(r) @@ -597,3 +603,19 @@ def valid_module(class_path): ) return worker + + @Deprecated(new="WorkerSet.foreach_policy_to_train", error=False) + def foreach_trainable_policy(self, func): + return self.foreach_policy_to_train(func) + + @Deprecated( + new="WorkerSet.is_policy_to_train([pid], [batch]?)", error=False) + def trainable_policies(self): + local_worker = self.local_worker() + if local_worker is not None: + return [ + local_worker.is_policy_to_train(pid, None) + for pid in local_worker.policy_map.keys() + ] + else: + raise NotImplementedError diff --git a/rllib/examples/random_parametric_agent.py b/rllib/examples/random_parametric_agent.py index 62c9dfe1ca76..993f07e31dfb 100644 --- a/rllib/examples/random_parametric_agent.py +++ b/rllib/examples/random_parametric_agent.py @@ -69,7 +69,7 @@ def execution_plan(workers: WorkerSet, config: TrainerConfigDict, # Collect batches for the trainable policies. rollouts = rollouts.for_each( - SelectExperiences(workers.trainable_policies())) + SelectExperiences(local_worker=workers.local_worker())) # Return training metrics. return StandardMetricsReporting(rollouts, workers, config) diff --git a/rllib/execution/multi_gpu_learner_thread.py b/rllib/execution/multi_gpu_learner_thread.py index 978be2502fd2..f8d02f6ceecc 100644 --- a/rllib/execution/multi_gpu_learner_thread.py +++ b/rllib/execution/multi_gpu_learner_thread.py @@ -154,7 +154,7 @@ def step(self) -> None: for pid in self.policy_map.keys(): # Not a policy-to-train. - if pid not in self.local_worker.policies_to_train: + if not self.local_worker.is_policy_to_train(pid): continue policy = self.policy_map[pid] default_policy_results = policy.learn_on_loaded_batch( @@ -209,7 +209,7 @@ def _step(self) -> None: # Load the batch into the idle stack. with self.load_timer: for pid in policy_map.keys(): - if pid not in s.local_worker.policies_to_train: + if not s.local_worker.is_policy_to_train(pid, batch): continue policy = policy_map[pid] policy.load_batch_into_buffer( diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index eae7ac8483cf..999420f3a0af 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -1,6 +1,6 @@ import logging import time -from typing import Any, Callable, Dict, List, Optional, Tuple, \ +from typing import Any, Callable, Container, Dict, List, Optional, Tuple, \ TYPE_CHECKING import ray @@ -404,19 +404,45 @@ class SelectExperiences: {"pol1", "pol2"} """ - def __init__(self, policy_ids: List[PolicyID]): - assert isinstance(policy_ids, list), policy_ids - self.policy_ids = policy_ids + def __init__(self, + policy_ids: Optional[Container[PolicyID]] = None, + local_worker: Optional["RolloutWorker"] = None): + """Initializes a SelectExperiences instance. + + Args: + policy_ids: Container of PolicyID to select from passing through + batches. If not provided, must provide the `local_worker` arg. 
+ local_worker: The local worker to use to determine, which policy + IDs are trainable. If not provided, must provide the + `policy_ids` arg. + """ + assert policy_ids is not None or local_worker is not None, \ + "ERROR: Must provide either one of `policy_ids` or " \ + "`local_worker` args!" + + self.local_worker = self.policy_ids = None + if local_worker: + self.local_worker = local_worker + else: + assert isinstance(policy_ids, Container), policy_ids + self.policy_ids = set(policy_ids) def __call__(self, samples: SampleBatchType) -> SampleBatchType: _check_sample_batch_type(samples) if isinstance(samples, MultiAgentBatch): - samples = MultiAgentBatch({ - k: v - for k, v in samples.policy_batches.items() - if k in self.policy_ids - }, samples.count) + if self.local_worker: + samples = MultiAgentBatch({ + pid: batch + for pid, batch in samples.policy_batches.items() + if self.local_worker.is_policy_to_train(pid, batch) + }, samples.count) + else: + samples = MultiAgentBatch({ + k: v + for k, v in samples.policy_batches.items() + if k in self.policy_ids + }, samples.count) return samples diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index b3f3c6ccc613..489aa615df7b 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -33,7 +33,6 @@ def train_one_step(trainer, train_batch) -> Dict: config = trainer.config workers = trainer.workers local_worker = workers.local_worker() - policies = local_worker.policies_to_train num_sgd_iter = config.get("num_sgd_iter", 1) sgd_minibatch_size = config.get("sgd_minibatch_size", 0) @@ -43,10 +42,10 @@ def train_one_step(trainer, train_batch) -> Dict: # train batch and loop through train batch `num_sgd_iter` times. if num_sgd_iter > 1 or sgd_minibatch_size > 0: info = do_minibatch_sgd( - train_batch, - {pid: local_worker.get_policy(pid) - for pid in policies}, local_worker, num_sgd_iter, - sgd_minibatch_size, []) + train_batch, { + pid: local_worker.get_policy(pid) + for pid in local_worker.get_policies_to_train(train_batch) + }, local_worker, num_sgd_iter, sgd_minibatch_size, []) # Single update step using train batch. else: info = local_worker.learn_on_batch(train_batch) @@ -63,7 +62,6 @@ def multi_gpu_train_one_step(trainer, train_batch) -> Dict: config = trainer.config workers = trainer.workers local_worker = workers.local_worker() - policies = local_worker.policies_to_train num_sgd_iter = config.get("sgd_num_iter", 1) sgd_minibatch_size = config.get("sgd_minibatch_size", config["train_batch_size"]) @@ -88,7 +86,7 @@ def multi_gpu_train_one_step(trainer, train_batch) -> Dict: num_loaded_samples = {} for policy_id, batch in train_batch.policy_batches.items(): # Not a policy-to-train. - if policy_id not in policies: + if not local_worker.is_policy_to_train(policy_id, train_batch): continue # Decompress SampleBatch, in case some columns are compressed. @@ -143,7 +141,9 @@ def multi_gpu_train_one_step(trainer, train_batch) -> Dict: # workers. 
if workers.remote_workers(): with trainer._timers[WORKER_UPDATE_TIMER]: - weights = ray.put(workers.local_worker().get_weights(policies)) + weights = ray.put( + local_worker.get_weights( + local_worker.get_policies_to_train(train_batch))) for e in workers.remote_workers(): e.set_weights.remote(weights) @@ -182,21 +182,20 @@ def __call__(self, _check_sample_batch_type(batch) metrics = _get_shared_metrics() learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER] + lw = self.local_worker with learn_timer: # Subsample minibatches (size=`sgd_minibatch_size`) from the # train batch and loop through train batch `num_sgd_iter` times. if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0: - lw = self.workers.local_worker() learner_info = do_minibatch_sgd( batch, { pid: lw.get_policy(pid) for pid in self.policies - or self.local_worker.policies_to_train + or lw.get_policies_to_train(batch) }, lw, self.num_sgd_iter, self.sgd_minibatch_size, []) # Single update step using train batch. else: - learner_info = \ - self.workers.local_worker().learn_on_batch(batch) + learner_info = lw.learn_on_batch(batch) metrics.info[LEARNER_INFO] = learner_info learn_timer.push_units_processed(batch.count) @@ -209,12 +208,13 @@ def __call__(self, # workers. if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: - weights = ray.put(self.workers.local_worker().get_weights( - self.policies or self.local_worker.policies_to_train)) + weights = ray.put( + lw.get_weights( + self.policies or lw.get_policies_to_train(batch))) for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) # Also update global vars of the local worker. - self.workers.local_worker().set_global_vars(_get_global_vars()) + lw.set_global_vars(_get_global_vars()) return batch, learner_info @@ -291,7 +291,8 @@ def __call__(self, num_loaded_samples = {} for policy_id, batch in samples.policy_batches.items(): # Not a policy-to-train. - if policy_id not in self.local_worker.policies_to_train: + if not self.local_worker.is_policy_to_train( + policy_id, samples): continue # Decompress SampleBatch, in case some columns are compressed. @@ -348,7 +349,7 @@ def __call__(self, if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: weights = ray.put(self.workers.local_worker().get_weights( - self.local_worker.policies_to_train)) + self.local_worker.get_policies_to_train())) for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) @@ -433,17 +434,19 @@ def __call__(self, item: Tuple[ModelGradients, int]) -> None: apply_timer = metrics.timers[APPLY_GRADS_TIMER] with apply_timer: - self.workers.local_worker().apply_gradients(gradients) + self.local_worker.apply_gradients(gradients) apply_timer.push_units_processed(count) # Also update global vars of the local worker. 
- self.workers.local_worker().set_global_vars(_get_global_vars()) + self.local_worker.set_global_vars(_get_global_vars()) if self.update_all: if self.workers.remote_workers(): with metrics.timers[WORKER_UPDATE_TIMER]: - weights = ray.put(self.workers.local_worker().get_weights( - self.policies or self.local_worker.policies_to_train)) + weights = ray.put( + self.local_worker.get_weights( + self.policies + or self.local_worker.get_policies_to_train())) for e in self.workers.remote_workers(): e.set_weights.remote(weights, _get_global_vars()) else: @@ -453,8 +456,8 @@ def __call__(self, item: Tuple[ModelGradients, int]) -> None: "update_all=False, `current_actor` must be set " "in the iterator context.") with metrics.timers[WORKER_UPDATE_TIMER]: - weights = self.workers.local_worker().get_weights( - self.policies or self.local_worker.policies_to_train) + weights = self.local_worker.get_weights( + self.policies or self.local_worker.get_policies_to_train()) metrics.current_actor.set_weights.remote( weights, _get_global_vars()) @@ -524,8 +527,9 @@ def __call__(self, _: Any) -> None: cur_ts = metrics.counters[self.metric] last_update = metrics.counters[LAST_TARGET_UPDATE_TS] if cur_ts - last_update > self.target_update_freq: - to_update = self.policies or self.local_worker.policies_to_train - self.workers.local_worker().foreach_trainable_policy( - lambda p, p_id: p_id in to_update and p.update_target()) + to_update = self.policies or self.local_worker.get_policies_to_train( + ) + self.workers.local_worker().foreach_policy_to_train( + lambda p, pid: pid in to_update and p.update_target()) metrics.counters[NUM_TARGET_UPDATES] += 1 metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts diff --git a/rllib/utils/__init__.py b/rllib/utils/__init__.py index a2c0942f6f7b..e8e686dd4feb 100644 --- a/rllib/utils/__init__.py +++ b/rllib/utils/__init__.py @@ -57,7 +57,7 @@ def force_list(elements=None, to_tuple=False): if to_tuple is True: ctor = tuple return ctor() if elements is None else ctor(elements) \ - if type(elements) in [list, tuple] else ctor([elements]) + if type(elements) in [list, set, tuple] else ctor([elements]) class NullContextManager(contextlib.AbstractContextManager):