[RLlib] Policy mapping fn can not be called with keyword arguments. (ray-project#31141)

Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
Jun Gong authored and WeichenXu123 committed Dec 19, 2022
1 parent 924004d commit 4837e48
Showing 21 changed files with 55 additions and 39 deletions.
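Every changed call site below follows the same pattern: the mapping function now takes agent_id, episode, and worker explicitly instead of relying on **kwargs. As a reference while reading the hunks, here is a minimal sketch (not part of the commit; the even/odd policy choice is illustrative only) of the signature and config usage the diff converges on:

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig


# Full signature used throughout the updated tests and examples.
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # `episode` and `worker` are accepted even when the mapping only
    # depends on the agent ID.
    return "p{}".format(agent_id % 2)


config = AlgorithmConfig().multi_agent(
    policies={"p0", "p1"},
    policy_mapping_fn=policy_mapping_fn,
)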
2 changes: 1 addition & 1 deletion rllib/algorithms/maddpg/tests/test_maddpg.py
@@ -38,7 +38,7 @@ def test_maddpg_compilation(self):
config=maddpg.MADDPGConfig.overrides(agent_id=1),
),
},
policy_mapping_fn=lambda agent_id, **kwargs: "pol2"
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol2"
if agent_id
else "pol1",
)
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_algorithm.py
@@ -196,7 +196,7 @@ def _has_policies(w):
# Note that the complete signature of a policy_mapping_fn
# is: `agent_id, episode, worker, **kwargs`.
policy_mapping_fn=(
- lambda agent_id, worker, episode, **kwargs: f"p{i - 1}"
+ lambda agent_id, episode, worker, **kwargs: f"p{i - 1}"
),
# Update list of policies to train.
policies_to_train=[f"p{i - 1}"],
6 changes: 5 additions & 1 deletion rllib/env/tests/test_external_multi_agent_env.py
@@ -69,7 +69,11 @@ def test_external_multi_agent_env_sample(self):
)
.multi_agent(
policies={"p0", "p1"},
policy_mapping_fn=lambda agent_id, **kwargs: "p{}".format(agent_id % 2),
policy_mapping_fn=(
lambda agent_id, episode, worker, **kwargs: "p{}".format(
agent_id % 2
)
),
),
)
batch = ev.sample()
22 changes: 11 additions & 11 deletions rllib/env/tests/test_multi_agent_env.py
@@ -177,10 +177,8 @@ def test_multi_agent_sample_sync_remote(self):
)
.multi_agent(
policies={"p0", "p1"},
- policy_mapping_fn=(
-     lambda agent_id, episode, worker, **kwargs: "p{}".format(
-         agent_id % 2
-     )
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+     "p{}".format(agent_id % 2)
),
),
)
@@ -200,8 +198,8 @@ def test_multi_agent_sample_async_remote(self):
)
.multi_agent(
policies={"p0", "p1"},
- policy_mapping_fn=(
-     lambda agent_id, **kwargs: "p{}".format(agent_id % 2)
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+     "p{}".format(agent_id % 2)
),
),
)
@@ -220,8 +218,8 @@ def test_sample_from_early_done_env(self):
)
.multi_agent(
policies={"p0", "p1"},
- policy_mapping_fn=(
-     lambda agent_id, **kwargs: "p{}".format(agent_id % 2)
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+     "p{}".format(agent_id % 2)
),
),
)
@@ -289,7 +287,7 @@ def test_multi_agent_sample_round_robin(self):
)
.multi_agent(
policies={"p0"},
policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
),
)
batch = ev.sample()
@@ -434,7 +432,7 @@ def compute_actions_from_input_dict(
)
.multi_agent(
policies={"p0", "p1"},
policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
),
)
batch = ev.sample()
@@ -492,7 +490,7 @@ def gen_policy():
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
policy_mapping_fn=lambda agent_id, **kwargs: "policy_1",
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
"policy_1"
),
)
.framework("tf")
)
4 changes: 3 additions & 1 deletion rllib/evaluation/episode_v2.py
@@ -116,7 +116,9 @@ def policy_for(
# duration of this episode to the returned PolicyID.
if agent_id not in self._agent_to_policy or refresh:
policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
- agent_id, episode=self, worker=self.worker
+ agent_id,  # agent_id
+ self,  # episode
+ worker=self.worker,
)
# Use already determined PolicyID.
else:
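The episode_v2.py hunk above is the call-site side of the change: EpisodeV2.policy_for() now passes the episode object positionally (worker is still passed as a keyword), so mapping functions that only accepted agent_id plus **kwargs would no longer be callable, which is why the remaining files in this diff update their signatures. A standalone sketch of that effect (illustrative only; None stands in for the real episode and worker objects):

def old_style(agent_id, **kwargs):
    return "p0"


def new_style(agent_id, episode, worker, **kwargs):
    return "p0"


# Mirrors the call `self.policy_mapping_fn(agent_id, self, worker=self.worker)`.
print(new_style("agent_0", None, worker=None))  # OK -> "p0"
try:
    old_style("agent_0", None, worker=None)
except TypeError as err:
    # The positional episode argument has nowhere to go in the old signature.
    print(err)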
4 changes: 3 additions & 1 deletion rllib/evaluation/tests/test_episode.py
@@ -131,7 +131,9 @@ def test_multi_agent_env(self):
.callbacks(LastInfoCallback)
.multi_agent(
policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
- policy_mapping_fn=lambda agent_id, episode, **kwargs: str(agent_id),
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+     str(agent_id)
+ ),
),
)
ev.sample()
4 changes: 3 additions & 1 deletion rllib/evaluation/tests/test_episode_v2.py
@@ -87,7 +87,9 @@ def test_multi_agent_env(self):
config=AlgorithmConfig()
.multi_agent(
policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
- policy_mapping_fn=lambda agent_id, episode, **kwargs: str(agent_id),
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+     str(agent_id)
+ ),
)
.rollouts(enable_connectors=True, num_rollout_workers=0),
)
4 changes: 2 additions & 2 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -739,7 +739,7 @@ def test_truncate_episodes(self):
.multi_agent(
policies={"pol0", "pol1"},
policy_mapping_fn=(
lambda agent_id, episode, **kwargs: "pol0"
lambda agent_id, episode, worker, **kwargs: "pol0"
if agent_id == 0
else "pol1"
),
@@ -764,7 +764,7 @@
count_steps_by="agent_steps",
policies={"pol0", "pol1"},
policy_mapping_fn=(
lambda agent_id, episode, **kwargs: "pol0"
lambda agent_id, episode, worker, **kwargs: "pol0"
if agent_id == 0
else "pol1"
),
8 changes: 5 additions & 3 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -286,7 +286,7 @@ def test_traj_view_lstm_functionality(self):
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, None),
}

- def policy_fn(agent_id, episode, **kwargs):
+ def policy_fn(agent_id, episode, worker, **kwargs):
return "pol0"

rw = RolloutWorker(
@@ -329,7 +329,7 @@ def test_traj_view_attention_functionality(self):
"pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, None),
}

- def policy_fn(agent_id, episode, **kwargs):
+ def policy_fn(agent_id, episode, worker, **kwargs):
return "pol0"

config = (
@@ -360,7 +360,9 @@ def test_counting_by_agent_steps(self):
config.framework("torch")
config.multi_agent(
policies={f"p{i}" for i in range(num_agents)},
policy_mapping_fn=lambda agent_id, **kwargs: "p{}".format(agent_id),
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
"p{}".format(agent_id)
),
count_steps_by="agent_steps",
)

4 changes: 2 additions & 2 deletions rllib/examples/centralized_critic.py
@@ -247,7 +247,7 @@ def get_default_policy_class(cls, config):


if __name__ == "__main__":
- ray.init()
+ ray.init(local_mode=True)
args = parser.parse_args()

ModelCatalog.register_custom_model(
@@ -280,7 +280,7 @@ def get_default_policy_class(cls, config):
PPOConfig.overrides(framework_str=args.framework),
),
},
policy_mapping_fn=lambda agent_id, **kwargs: "pol1"
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
if agent_id == 0
else "pol2",
)
2 changes: 1 addition & 1 deletion rllib/examples/centralized_critic_2.py
@@ -132,7 +132,7 @@ def central_critic_observer(agent_obs, **kw):
"pol1": (None, observer_space, action_space, {}),
"pol2": (None, observer_space, action_space, {}),
},
policy_mapping_fn=lambda agent_id, **kwargs: "pol1"
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
if agent_id == 0
else "pol2",
observation_fn=central_critic_observer,
2 changes: 1 addition & 1 deletion rllib/examples/coin_game_env.py
@@ -50,7 +50,7 @@ def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
{},
),
},
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
"policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
},
# Size of batches collected from each worker.
"rollout_fragment_length": 20,
2 changes: 1 addition & 1 deletion rllib/examples/iterated_prisoners_dilemma_env.py
@@ -74,7 +74,7 @@ def get_rllib_config(seeds, debug=False, stop_iters=200, framework="tf"):
{},
),
},
"policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
"policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
},
"seed": tune.grid_search(seeds),
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
8 changes: 5 additions & 3 deletions rllib/examples/multi_agent_custom_policy.py
@@ -73,9 +73,11 @@
},
# Map to either random behavior or PR learning behavior based on
# the agent's ID.
policy_mapping_fn=lambda agent_id, **kwargs: ["pg_policy", "random"][
agent_id % 2
],
policy_mapping_fn=(
lambda agent_id, episode, worker, **kwargs: (
["pg_policy", "random"][agent_id % 2]
)
),
# We wouldn't have to specify this here as the RandomPolicy does
# not learn anyways (it has an empty `learn_on_batch` method), but
# it's good practice to define this list here either way.
4 changes: 3 additions & 1 deletion rllib/examples/multi_agent_independent_learning.py
@@ -31,7 +31,9 @@ def env_creator(args):
# Method specific
"multiagent": {
"policies": set(env.agents),
"policy_mapping_fn": (lambda agent_id, episode, **kwargs: agent_id),
"policy_mapping_fn": (
lambda agent_id, episode, worker, **kwargs: agent_id
),
},
},
).fit()
2 changes: 1 addition & 1 deletion rllib/examples/multi_agent_parameter_sharing.py
@@ -47,7 +47,7 @@
"policies": {"shared_policy"},
# Always use "shared" policy.
"policy_mapping_fn": (
lambda agent_id, episode, **kwargs: "shared_policy"
lambda agent_id, episode, worker, **kwargs: "shared_policy"
),
},
},
2 changes: 1 addition & 1 deletion rllib/examples/sumo_env_local.py
@@ -148,7 +148,7 @@
)
config.multi_agent(
policies=policies,
- policy_mapping_fn=lambda agent_id, episode, **kwargs: agent_id,
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: agent_id,
policies_to_train=["ppo_policy"],
)
config.environment("sumo_test_env", env_config=env_config)
2 changes: 1 addition & 1 deletion rllib/examples/two_step_game.py
@@ -129,7 +129,7 @@
config=config.overrides(agent_id=1),
),
},
policy_mapping_fn=lambda agent_id, **kwargs: "pol2"
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol2"
if agent_id
else "pol1",
)
4 changes: 2 additions & 2 deletions rllib/tests/test_io.py
@@ -230,8 +230,8 @@ def test_multi_agent(self):
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={"policy_1", "policy_2"},
- policy_mapping_fn=(
-     lambda agent_id, **kwargs: random.choice(["policy_1", "policy_2"])
+ policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+     random.choice(["policy_1", "policy_2"])
),
)
)
2 changes: 1 addition & 1 deletion rllib/tests/test_nested_observation_spaces.py
@@ -518,7 +518,7 @@ def test_multi_agent_complex_spaces(self):
),
},
policy_mapping_fn=(
- lambda agent_id, **kwargs: {
+ lambda agent_id, episode, worker, **kwargs: {
"tuple_agent": "tuple_policy",
"dict_agent": "dict_policy",
}[agent_id]
4 changes: 2 additions & 2 deletions rllib/tests/test_pettingzoo_env.py
@@ -41,7 +41,7 @@ def env_creator(config):
# Setup a single, shared policy for all agents.
policies={"av": (None, observation_space, action_space, {})},
# Map all agents to that policy.
policy_mapping_fn=lambda agent_id, episode, **kwargs: "av",
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "av",
)
.debugging(log_level="DEBUG")
.rollouts(
@@ -77,7 +77,7 @@ def test_pettingzoo_env(self):
policies={"av": (None, observation_space, action_space, {})},
# Mapping function that always returns "av" as policy ID to use
# (for any agent).
policy_mapping_fn=lambda agent_id, episode, **kwargs: "av",
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "av",
)
)

