diff --git a/rllib/algorithms/maddpg/tests/test_maddpg.py b/rllib/algorithms/maddpg/tests/test_maddpg.py
index b1fdab69f065..a60300ea486e 100644
--- a/rllib/algorithms/maddpg/tests/test_maddpg.py
+++ b/rllib/algorithms/maddpg/tests/test_maddpg.py
@@ -38,7 +38,7 @@ def test_maddpg_compilation(self):
                     config=maddpg.MADDPGConfig.overrides(agent_id=1),
                 ),
             },
-            policy_mapping_fn=lambda agent_id, **kwargs: "pol2"
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol2"
             if agent_id
             else "pol1",
         )
diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py
index 1485d81470ec..73b03bd634fa 100644
--- a/rllib/algorithms/tests/test_algorithm.py
+++ b/rllib/algorithms/tests/test_algorithm.py
@@ -196,7 +196,7 @@ def _has_policies(w):
                 # Note that the complete signature of a policy_mapping_fn
                 # is: `agent_id, episode, worker, **kwargs`.
                 policy_mapping_fn=(
-                    lambda agent_id, worker, episode, **kwargs: f"p{i - 1}"
+                    lambda agent_id, episode, worker, **kwargs: f"p{i - 1}"
                 ),
                 # Update list of policies to train.
                 policies_to_train=[f"p{i - 1}"],
diff --git a/rllib/env/tests/test_external_multi_agent_env.py b/rllib/env/tests/test_external_multi_agent_env.py
index 5570db2a5725..474d9746ff5d 100644
--- a/rllib/env/tests/test_external_multi_agent_env.py
+++ b/rllib/env/tests/test_external_multi_agent_env.py
@@ -69,7 +69,11 @@ def test_external_multi_agent_env_sample(self):
             )
             .multi_agent(
                 policies={"p0", "p1"},
-                policy_mapping_fn=lambda agent_id, **kwargs: "p{}".format(agent_id % 2),
+                policy_mapping_fn=(
+                    lambda agent_id, episode, worker, **kwargs: "p{}".format(
+                        agent_id % 2
+                    )
+                ),
             ),
         )
         batch = ev.sample()
diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py
index 484a1eb76d79..f587d6308b35 100644
--- a/rllib/env/tests/test_multi_agent_env.py
+++ b/rllib/env/tests/test_multi_agent_env.py
@@ -177,10 +177,8 @@ def test_multi_agent_sample_sync_remote(self):
             )
             .multi_agent(
                 policies={"p0", "p1"},
-                policy_mapping_fn=(
-                    lambda agent_id, episode, worker, **kwargs: "p{}".format(
-                        agent_id % 2
-                    )
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    "p{}".format(agent_id % 2)
                 ),
             ),
         )
@@ -200,8 +198,8 @@ def test_multi_agent_sample_async_remote(self):
             )
             .multi_agent(
                 policies={"p0", "p1"},
-                policy_mapping_fn=(
-                    lambda agent_id, **kwargs: "p{}".format(agent_id % 2)
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    "p{}".format(agent_id % 2)
                 ),
             ),
         )
@@ -220,8 +218,8 @@ def test_sample_from_early_done_env(self):
             )
             .multi_agent(
                 policies={"p0", "p1"},
-                policy_mapping_fn=(
-                    lambda agent_id, **kwargs: "p{}".format(agent_id % 2)
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    "p{}".format(agent_id % 2)
                 ),
             ),
         )
@@ -289,7 +287,7 @@ def test_multi_agent_sample_round_robin(self):
             )
             .multi_agent(
                 policies={"p0"},
-                policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
             ),
         )
         batch = ev.sample()
@@ -434,7 +432,7 @@ def compute_actions_from_input_dict(
             )
             .multi_agent(
                 policies={"p0", "p1"},
-                policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
             ),
         )
         batch = ev.sample()
@@ -492,7 +490,9 @@ def gen_policy():
                     "policy_1": gen_policy(),
                     "policy_2": gen_policy(),
                 },
-                policy_mapping_fn=lambda agent_id, **kwargs: "policy_1",
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    "policy_1"
+                ),
             )
             .framework("tf")
         )
diff --git a/rllib/evaluation/episode_v2.py b/rllib/evaluation/episode_v2.py
index 5986216fc809..0e50505703fc 100644
--- a/rllib/evaluation/episode_v2.py
+++ b/rllib/evaluation/episode_v2.py
@@ -116,7 +116,9 @@ def policy_for(
         # duration of this episode to the returned PolicyID.
         if agent_id not in self._agent_to_policy or refresh:
             policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
-                agent_id, episode=self, worker=self.worker
+                agent_id,  # agent_id
+                self,  # episode
+                worker=self.worker,
             )
         # Use already determined PolicyID.
         else:
diff --git a/rllib/evaluation/tests/test_episode.py b/rllib/evaluation/tests/test_episode.py
index 1e01ed29dc72..69dbebed275c 100644
--- a/rllib/evaluation/tests/test_episode.py
+++ b/rllib/evaluation/tests/test_episode.py
@@ -131,7 +131,9 @@ def test_multi_agent_env(self):
             .callbacks(LastInfoCallback)
             .multi_agent(
                 policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
-                policy_mapping_fn=lambda agent_id, episode, **kwargs: str(agent_id),
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    str(agent_id)
+                ),
             ),
         )
         ev.sample()
diff --git a/rllib/evaluation/tests/test_episode_v2.py b/rllib/evaluation/tests/test_episode_v2.py
index 12f7e56642cd..245737d85a2d 100644
--- a/rllib/evaluation/tests/test_episode_v2.py
+++ b/rllib/evaluation/tests/test_episode_v2.py
@@ -87,7 +87,9 @@ def test_multi_agent_env(self):
             config=AlgorithmConfig()
             .multi_agent(
                 policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
-                policy_mapping_fn=lambda agent_id, episode, **kwargs: str(agent_id),
+                policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                    str(agent_id)
+                ),
             )
             .rollouts(enable_connectors=True, num_rollout_workers=0),
         )
diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py
index d114184d0d61..6724f0d61c33 100644
--- a/rllib/evaluation/tests/test_rollout_worker.py
+++ b/rllib/evaluation/tests/test_rollout_worker.py
@@ -739,7 +739,7 @@ def test_truncate_episodes(self):
             .multi_agent(
                 policies={"pol0", "pol1"},
                 policy_mapping_fn=(
-                    lambda agent_id, episode, **kwargs: "pol0"
+                    lambda agent_id, episode, worker, **kwargs: "pol0"
                     if agent_id == 0
                     else "pol1"
                 ),
@@ -764,7 +764,7 @@ def test_truncate_episodes(self):
                 count_steps_by="agent_steps",
                 policies={"pol0", "pol1"},
                 policy_mapping_fn=(
-                    lambda agent_id, episode, **kwargs: "pol0"
+                    lambda agent_id, episode, worker, **kwargs: "pol0"
                     if agent_id == 0
                     else "pol1"
                 ),
diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py
index 1f30e60306a8..8f209803d97f 100644
--- a/rllib/evaluation/tests/test_trajectory_view_api.py
+++ b/rllib/evaluation/tests/test_trajectory_view_api.py
@@ -286,7 +286,7 @@ def test_traj_view_lstm_functionality(self):
             "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, None),
         }

-        def policy_fn(agent_id, episode, **kwargs):
+        def policy_fn(agent_id, episode, worker, **kwargs):
             return "pol0"

         rw = RolloutWorker(
@@ -329,7 +329,7 @@ def test_traj_view_attention_functionality(self):
             "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, None),
         }

-        def policy_fn(agent_id, episode, **kwargs):
+        def policy_fn(agent_id, episode, worker, **kwargs):
             return "pol0"

         config = (
@@ -360,7 +360,9 @@ def test_counting_by_agent_steps(self):
         config.framework("torch")
         config.multi_agent(
             policies={f"p{i}" for i in range(num_agents)},
-            policy_mapping_fn=lambda agent_id, **kwargs: "p{}".format(agent_id),
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                "p{}".format(agent_id)
+            ),
             count_steps_by="agent_steps",
         )
diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py
index 2702fd5b4892..3923177f1ac6 100644
--- a/rllib/examples/centralized_critic.py
+++ b/rllib/examples/centralized_critic.py
@@ -247,7 +247,7 @@ def get_default_policy_class(cls, config):


 if __name__ == "__main__":
-    ray.init()
+    ray.init(local_mode=True)
     args = parser.parse_args()

     ModelCatalog.register_custom_model(
@@ -280,7 +280,7 @@ def get_default_policy_class(cls, config):
                     PPOConfig.overrides(framework_str=args.framework),
                 ),
             },
-            policy_mapping_fn=lambda agent_id, **kwargs: "pol1"
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
             if agent_id == 0
             else "pol2",
         )
diff --git a/rllib/examples/centralized_critic_2.py b/rllib/examples/centralized_critic_2.py
index 66682a8b5c23..55cd4afee30f 100644
--- a/rllib/examples/centralized_critic_2.py
+++ b/rllib/examples/centralized_critic_2.py
@@ -132,7 +132,7 @@ def central_critic_observer(agent_obs, **kw):
                 "pol1": (None, observer_space, action_space, {}),
                 "pol2": (None, observer_space, action_space, {}),
             },
-            policy_mapping_fn=lambda agent_id, **kwargs: "pol1"
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
             if agent_id == 0
             else "pol2",
             observation_fn=central_critic_observer,
diff --git a/rllib/examples/coin_game_env.py b/rllib/examples/coin_game_env.py
index b2a98202d572..0453a7d5abc8 100644
--- a/rllib/examples/coin_game_env.py
+++ b/rllib/examples/coin_game_env.py
@@ -50,7 +50,7 @@ def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
                     {},
                 ),
             },
-            "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
+            "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
         },
         # Size of batches collected from each worker.
         "rollout_fragment_length": 20,
diff --git a/rllib/examples/iterated_prisoners_dilemma_env.py b/rllib/examples/iterated_prisoners_dilemma_env.py
index a5c1ccbfff7b..6aac2e670e26 100644
--- a/rllib/examples/iterated_prisoners_dilemma_env.py
+++ b/rllib/examples/iterated_prisoners_dilemma_env.py
@@ -74,7 +74,7 @@ def get_rllib_config(seeds, debug=False, stop_iters=200, framework="tf"):
                     {},
                 ),
             },
-            "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
+            "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
         },
         "seed": tune.grid_search(seeds),
         "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
diff --git a/rllib/examples/multi_agent_custom_policy.py b/rllib/examples/multi_agent_custom_policy.py
index 6119cd8cfdf8..26cddd16a1da 100644
--- a/rllib/examples/multi_agent_custom_policy.py
+++ b/rllib/examples/multi_agent_custom_policy.py
@@ -73,9 +73,11 @@
             },
             # Map to either random behavior or PR learning behavior based on
             # the agent's ID.
-            policy_mapping_fn=lambda agent_id, **kwargs: ["pg_policy", "random"][
-                agent_id % 2
-            ],
+            policy_mapping_fn=(
+                lambda agent_id, episode, worker, **kwargs: (
+                    ["pg_policy", "random"][agent_id % 2]
+                )
+            ),
             # We wouldn't have to specify this here as the RandomPolicy does
             # not learn anyways (it has an empty `learn_on_batch` method), but
             # it's good practice to define this list here either way.
diff --git a/rllib/examples/multi_agent_independent_learning.py b/rllib/examples/multi_agent_independent_learning.py
index 8f9ee405c971..365f8ecb455a 100644
--- a/rllib/examples/multi_agent_independent_learning.py
+++ b/rllib/examples/multi_agent_independent_learning.py
@@ -31,7 +31,9 @@ def env_creator(args):
             # Method specific
             "multiagent": {
                 "policies": set(env.agents),
-                "policy_mapping_fn": (lambda agent_id, episode, **kwargs: agent_id),
+                "policy_mapping_fn": (
+                    lambda agent_id, episode, worker, **kwargs: agent_id
+                ),
             },
         },
     ).fit()
diff --git a/rllib/examples/multi_agent_parameter_sharing.py b/rllib/examples/multi_agent_parameter_sharing.py
index 445c3f79309c..676e381761f0 100644
--- a/rllib/examples/multi_agent_parameter_sharing.py
+++ b/rllib/examples/multi_agent_parameter_sharing.py
@@ -47,7 +47,7 @@
                 "policies": {"shared_policy"},
                 # Always use "shared" policy.
                 "policy_mapping_fn": (
-                    lambda agent_id, episode, **kwargs: "shared_policy"
+                    lambda agent_id, episode, worker, **kwargs: "shared_policy"
                 ),
             },
         },
diff --git a/rllib/examples/sumo_env_local.py b/rllib/examples/sumo_env_local.py
index 2347c469e275..c1840137be30 100644
--- a/rllib/examples/sumo_env_local.py
+++ b/rllib/examples/sumo_env_local.py
@@ -148,7 +148,7 @@
     )
     config.multi_agent(
         policies=policies,
-        policy_mapping_fn=lambda agent_id, episode, **kwargs: agent_id,
+        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: agent_id,
         policies_to_train=["ppo_policy"],
     )
     config.environment("sumo_test_env", env_config=env_config)
diff --git a/rllib/examples/two_step_game.py b/rllib/examples/two_step_game.py
index ca826774f628..8ddb3932d7da 100644
--- a/rllib/examples/two_step_game.py
+++ b/rllib/examples/two_step_game.py
@@ -129,7 +129,7 @@
                 config=config.overrides(agent_id=1),
             ),
         },
-        policy_mapping_fn=lambda agent_id, **kwargs: "pol2"
+        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol2"
         if agent_id
         else "pol1",
     )
diff --git a/rllib/tests/test_io.py b/rllib/tests/test_io.py
index ac09c10c3afe..8b4234a9ce28 100644
--- a/rllib/tests/test_io.py
+++ b/rllib/tests/test_io.py
@@ -230,8 +230,8 @@ def test_multi_agent(self):
            .rollouts(num_rollout_workers=0)
            .multi_agent(
                policies={"policy_1", "policy_2"},
-               policy_mapping_fn=(
-                   lambda agent_id, **kwargs: random.choice(["policy_1", "policy_2"])
+               policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                   random.choice(["policy_1", "policy_2"])
                ),
            )
        )
diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py
index 346ec6324f72..484a03d6b962 100644
--- a/rllib/tests/test_nested_observation_spaces.py
+++ b/rllib/tests/test_nested_observation_spaces.py
@@ -518,7 +518,7 @@ def test_multi_agent_complex_spaces(self):
                    ),
                },
                policy_mapping_fn=(
-                   lambda agent_id, **kwargs: {
+                   lambda agent_id, episode, worker, **kwargs: {
                        "tuple_agent": "tuple_policy",
                        "dict_agent": "dict_policy",
                    }[agent_id]
diff --git a/rllib/tests/test_pettingzoo_env.py b/rllib/tests/test_pettingzoo_env.py
index fe3033024cf7..70cd6a53eeb4 100644
--- a/rllib/tests/test_pettingzoo_env.py
+++ b/rllib/tests/test_pettingzoo_env.py
@@ -41,7 +41,7 @@ def env_creator(config):
                # Setup a single, shared policy for all agents.
                policies={"av": (None, observation_space, action_space, {})},
                # Map all agents to that policy.
-               policy_mapping_fn=lambda agent_id, episode, **kwargs: "av",
+               policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "av",
            )
            .debugging(log_level="DEBUG")
            .rollouts(
@@ -77,7 +77,7 @@ def test_pettingzoo_env(self):
                policies={"av": (None, observation_space, action_space, {})},
                # Mapping function that always returns "av" as policy ID to use
                # (for any agent).
-               policy_mapping_fn=lambda agent_id, episode, **kwargs: "av",
+               policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "av",
            )
        )
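Every call site above converges on the full policy_mapping_fn signature noted in the test_algorithm.py comment: `agent_id, episode, worker, **kwargs`. The following is a minimal, self-contained sketch of that signature in use; it is not taken from any file in this diff, and the policy IDs and the even/odd mapping are purely illustrative.

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig


# Full signature: positional `agent_id`, then `episode` and `worker`
# (which callers may pass positionally or as keywords), plus **kwargs
# for forward compatibility. The return value is a PolicyID string.
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Illustrative mapping, assuming integer agent IDs as in the tests above:
    # even IDs -> "p0", odd IDs -> "p1".
    return "p{}".format(agent_id % 2)


config = AlgorithmConfig().multi_agent(
    policies={"p0", "p1"},
    policy_mapping_fn=policy_mapping_fn,
)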