[RLlib] Policy mapping fn can not be called with keyword arguments. #31141

Merged 5 commits on Dec 16, 2022
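Note: every mapping function touched by this PR is moved to the full `(agent_id, episode, worker, **kwargs)` signature. As a minimal, illustrative sketch (the policy IDs, the parity-based mapping, and the env name are assumptions, not taken from this diff), such a function plugs into `AlgorithmConfig.multi_agent()` like this:

from ray.rllib.algorithms.ppo import PPOConfig


def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Full signature used throughout this PR; keep **kwargs so additional
    # arguments can be passed later without breaking the function.
    # Assumes integer agent IDs, as in the tests below.
    return "p{}".format(agent_id % 2)


config = (
    PPOConfig()
    .environment("my_multi_agent_env")  # hypothetical registered env name
    .multi_agent(
        policies={"p0", "p1"},
        policy_mapping_fn=policy_mapping_fn,
    )
)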
2 changes: 1 addition & 1 deletion rllib/algorithms/maddpg/tests/test_maddpg.py
@@ -38,7 +38,7 @@ def test_maddpg_compilation(self):
                 config=maddpg.MADDPGConfig.overrides(agent_id=1),
             ),
         },
-        policy_mapping_fn=lambda agent_id, **kwargs: "pol2"
+        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol2"
         if agent_id
         else "pol1",
     )
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_algorithm.py
@@ -196,7 +196,7 @@ def _has_policies(w):
             # Note that the complete signature of a policy_mapping_fn
             # is: `agent_id, episode, worker, **kwargs`.
             policy_mapping_fn=(
-                lambda agent_id, worker, episode, **kwargs: f"p{i - 1}"
+                lambda agent_id, episode, worker, **kwargs: f"p{i - 1}"
             ),
             # Update list of policies to train.
             policies_to_train=[f"p{i - 1}"],
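Why the argument order in this hunk matters: with the episode now passed positionally (see the `rllib/evaluation/episode_v2.py` change further down), a function declared `(agent_id, worker, episode, **kwargs)` receives the episode in its `worker` slot, and the explicit `worker=` keyword then collides with it. A tiny, hypothetical illustration with placeholder objects (not part of the test):

episode, worker = object(), object()  # placeholders for Episode / RolloutWorker

swapped = lambda agent_id, worker, episode, **kw: "p0"
fixed = lambda agent_id, episode, worker, **kw: "p0"

fixed("agent_0", episode, worker=worker)    # OK, returns "p0"
swapped("agent_0", episode, worker=worker)  # TypeError: got multiple values
                                            # for argument 'worker'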
6 changes: 5 additions & 1 deletion rllib/env/tests/test_external_multi_agent_env.py
@@ -69,7 +69,11 @@ def test_external_multi_agent_env_sample(self):
         )
         .multi_agent(
             policies={"p0", "p1"},
-            policy_mapping_fn=lambda agent_id, **kwargs: "p{}".format(agent_id % 2),
+            policy_mapping_fn=(
+                lambda agent_id, episode, worker, **kwargs: "p{}".format(
+                    agent_id % 2
+                )
+            ),
         ),
     )
     batch = ev.sample()
22 changes: 11 additions & 11 deletions rllib/env/tests/test_multi_agent_env.py
@@ -177,10 +177,8 @@ def test_multi_agent_sample_sync_remote(self):
         )
         .multi_agent(
             policies={"p0", "p1"},
-            policy_mapping_fn=(
-                lambda agent_id, episode, worker, **kwargs: "p{}".format(
-                    agent_id % 2
-                )
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                "p{}".format(agent_id % 2)
             ),
         ),
     )
@@ -200,8 +198,8 @@ def test_multi_agent_sample_async_remote(self):
         )
         .multi_agent(
             policies={"p0", "p1"},
-            policy_mapping_fn=(
-                lambda agent_id, **kwargs: "p{}".format(agent_id % 2)
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                "p{}".format(agent_id % 2)
             ),
         ),
     )
@@ -220,8 +218,8 @@ def test_sample_from_early_done_env(self):
         )
         .multi_agent(
             policies={"p0", "p1"},
-            policy_mapping_fn=(
-                lambda agent_id, **kwargs: "p{}".format(agent_id % 2)
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                "p{}".format(agent_id % 2)
            ),
         ),
     )
@@ -289,7 +287,7 @@ def test_multi_agent_sample_round_robin(self):
         )
         .multi_agent(
             policies={"p0"},
-            policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
         ),
     )
     batch = ev.sample()
@@ -434,7 +432,7 @@ def compute_actions_from_input_dict(
         )
         .multi_agent(
             policies={"p0", "p1"},
-            policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
         ),
     )
     batch = ev.sample()
@@ -492,7 +490,9 @@ def gen_policy():
                 "policy_1": gen_policy(),
                 "policy_2": gen_policy(),
             },
-            policy_mapping_fn=lambda agent_id, **kwargs: "policy_1",
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                "policy_1"
+            ),
         )
         .framework("tf")
     )
4 changes: 3 additions & 1 deletion rllib/evaluation/episode_v2.py
@@ -116,7 +116,9 @@ def policy_for(
         # duration of this episode to the returned PolicyID.
         if agent_id not in self._agent_to_policy or refresh:
             policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn(
-                agent_id, episode=self, worker=self.worker
+                agent_id,  # agent_id
+                self,  # episode
+                worker=self.worker,
             )
         # Use already determined PolicyID.
         else:
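The `episode_v2.py` hunk above is the behavioral core of the PR: `agent_id` and the episode are now passed positionally, and only `worker` stays a keyword. A minimal sketch of the resulting call contract (names are illustrative, not from the diff):

episode = object()  # stands in for the EpisodeV2 instance
worker = object()   # stands in for the RolloutWorker


def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Compatible definition: the episode arrives positionally, worker as a keyword.
    return "shared_policy"


# Mirrors the updated call in EpisodeV2.policy_for() shown above:
pid = policy_mapping_fn("agent_0", episode, worker=worker)  # -> "shared_policy"

# A bare `lambda agent_id, **kwargs: ...` cannot absorb the positional episode
# argument (TypeError: takes 1 positional argument but 2 were given), which is
# why the tests and examples in this PR spell out the full
# `agent_id, episode, worker, **kwargs` signature.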
4 changes: 3 additions & 1 deletion rllib/evaluation/tests/test_episode.py
@@ -131,7 +131,9 @@ def test_multi_agent_env(self):
         .callbacks(LastInfoCallback)
         .multi_agent(
             policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
-            policy_mapping_fn=lambda agent_id, episode, **kwargs: str(agent_id),
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                str(agent_id)
+            ),
         ),
     )
     ev.sample()
4 changes: 3 additions & 1 deletion rllib/evaluation/tests/test_episode_v2.py
@@ -87,7 +87,9 @@ def test_multi_agent_env(self):
         config=AlgorithmConfig()
         .multi_agent(
             policies={str(agent_id) for agent_id in range(NUM_AGENTS)},
-            policy_mapping_fn=lambda agent_id, episode, **kwargs: str(agent_id),
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                str(agent_id)
+            ),
         )
         .rollouts(enable_connectors=True, num_rollout_workers=0),
     )
4 changes: 2 additions & 2 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -739,7 +739,7 @@ def test_truncate_episodes(self):
         .multi_agent(
             policies={"pol0", "pol1"},
             policy_mapping_fn=(
-                lambda agent_id, episode, **kwargs: "pol0"
+                lambda agent_id, episode, worker, **kwargs: "pol0"
                 if agent_id == 0
                 else "pol1"
             ),
@@ -764,7 +764,7 @@
             count_steps_by="agent_steps",
             policies={"pol0", "pol1"},
             policy_mapping_fn=(
-                lambda agent_id, episode, **kwargs: "pol0"
+                lambda agent_id, episode, worker, **kwargs: "pol0"
                 if agent_id == 0
                 else "pol1"
             ),
8 changes: 5 additions & 3 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -286,7 +286,7 @@ def test_traj_view_lstm_functionality(self):
             "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, None),
         }

-        def policy_fn(agent_id, episode, **kwargs):
+        def policy_fn(agent_id, episode, worker, **kwargs):
             return "pol0"

         rw = RolloutWorker(
@@ -329,7 +329,7 @@ def test_traj_view_attention_functionality(self):
             "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, None),
         }

-        def policy_fn(agent_id, episode, **kwargs):
+        def policy_fn(agent_id, episode, worker, **kwargs):
             return "pol0"

         config = (
@@ -360,7 +360,9 @@ def test_counting_by_agent_steps(self):
         config.framework("torch")
         config.multi_agent(
             policies={f"p{i}" for i in range(num_agents)},
-            policy_mapping_fn=lambda agent_id, **kwargs: "p{}".format(agent_id),
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                "p{}".format(agent_id)
+            ),
             count_steps_by="agent_steps",
         )

4 changes: 2 additions & 2 deletions rllib/examples/centralized_critic.py
@@ -247,7 +247,7 @@ def get_default_policy_class(cls, config):


 if __name__ == "__main__":
-    ray.init()
+    ray.init(local_mode=True)
     args = parser.parse_args()

     ModelCatalog.register_custom_model(
@@ -280,7 +280,7 @@ def get_default_policy_class(cls, config):
                 PPOConfig.overrides(framework_str=args.framework),
             ),
         },
-        policy_mapping_fn=lambda agent_id, **kwargs: "pol1"
+        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
         if agent_id == 0
         else "pol2",
     )
2 changes: 1 addition & 1 deletion rllib/examples/centralized_critic_2.py
@@ -132,7 +132,7 @@ def central_critic_observer(agent_obs, **kw):
         "pol1": (None, observer_space, action_space, {}),
         "pol2": (None, observer_space, action_space, {}),
     },
-    policy_mapping_fn=lambda agent_id, **kwargs: "pol1"
+    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1"
     if agent_id == 0
     else "pol2",
     observation_fn=central_critic_observer,
2 changes: 1 addition & 1 deletion rllib/examples/coin_game_env.py
@@ -50,7 +50,7 @@ def main(debug, stop_iters=2000, tf=False, asymmetric_env=False):
                 {},
             ),
         },
-        "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
+        "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
     },
     # Size of batches collected from each worker.
     "rollout_fragment_length": 20,
2 changes: 1 addition & 1 deletion rllib/examples/iterated_prisoners_dilemma_env.py
@@ -74,7 +74,7 @@ def get_rllib_config(seeds, debug=False, stop_iters=200, framework="tf"):
                 {},
             ),
         },
-        "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
+        "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: agent_id,
     },
     "seed": tune.grid_search(seeds),
     "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
8 changes: 5 additions & 3 deletions rllib/examples/multi_agent_custom_policy.py
@@ -73,9 +73,11 @@
         },
         # Map to either random behavior or PG learning behavior based on
         # the agent's ID.
-        policy_mapping_fn=lambda agent_id, **kwargs: ["pg_policy", "random"][
-            agent_id % 2
-        ],
+        policy_mapping_fn=(
+            lambda agent_id, episode, worker, **kwargs: (
+                ["pg_policy", "random"][agent_id % 2]
+            )
+        ),
         # We wouldn't have to specify this here as the RandomPolicy does
         # not learn anyways (it has an empty `learn_on_batch` method), but
         # it's good practice to define this list here either way.
4 changes: 3 additions & 1 deletion rllib/examples/multi_agent_independent_learning.py
@@ -31,7 +31,9 @@ def env_creator(args):
             # Method specific
             "multiagent": {
                 "policies": set(env.agents),
-                "policy_mapping_fn": (lambda agent_id, episode, **kwargs: agent_id),
+                "policy_mapping_fn": (
+                    lambda agent_id, episode, worker, **kwargs: agent_id
+                ),
             },
         },
     ).fit()
2 changes: 1 addition & 1 deletion rllib/examples/multi_agent_parameter_sharing.py
@@ -47,7 +47,7 @@
         "policies": {"shared_policy"},
         # Always use "shared" policy.
         "policy_mapping_fn": (
-            lambda agent_id, episode, **kwargs: "shared_policy"
+            lambda agent_id, episode, worker, **kwargs: "shared_policy"
         ),
     },
 },
2 changes: 1 addition & 1 deletion rllib/examples/sumo_env_local.py
@@ -148,7 +148,7 @@
 )
 config.multi_agent(
     policies=policies,
-    policy_mapping_fn=lambda agent_id, episode, **kwargs: agent_id,
+    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: agent_id,
     policies_to_train=["ppo_policy"],
 )
 config.environment("sumo_test_env", env_config=env_config)
2 changes: 1 addition & 1 deletion rllib/examples/two_step_game.py
@@ -129,7 +129,7 @@
         config=config.overrides(agent_id=1),
     ),
 },
-policy_mapping_fn=lambda agent_id, **kwargs: "pol2"
+policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol2"
 if agent_id
 else "pol1",
 )
4 changes: 2 additions & 2 deletions rllib/tests/test_io.py
@@ -230,8 +230,8 @@ def test_multi_agent(self):
         .rollouts(num_rollout_workers=0)
         .multi_agent(
             policies={"policy_1", "policy_2"},
-            policy_mapping_fn=(
-                lambda agent_id, **kwargs: random.choice(["policy_1", "policy_2"])
+            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
+                random.choice(["policy_1", "policy_2"])
             ),
         )
     )
2 changes: 1 addition & 1 deletion rllib/tests/test_nested_observation_spaces.py
@@ -518,7 +518,7 @@ def test_multi_agent_complex_spaces(self):
         ),
     },
     policy_mapping_fn=(
-        lambda agent_id, **kwargs: {
+        lambda agent_id, episode, worker, **kwargs: {
             "tuple_agent": "tuple_policy",
             "dict_agent": "dict_policy",
         }[agent_id]
4 changes: 2 additions & 2 deletions rllib/tests/test_pettingzoo_env.py
@@ -41,7 +41,7 @@ def env_creator(config):
         # Setup a single, shared policy for all agents.
         policies={"av": (None, observation_space, action_space, {})},
         # Map all agents to that policy.
-        policy_mapping_fn=lambda agent_id, episode, **kwargs: "av",
+        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "av",
     )
     .debugging(log_level="DEBUG")
     .rollouts(
@@ -77,7 +77,7 @@ def test_pettingzoo_env(self):
         policies={"av": (None, observation_space, action_space, {})},
         # Mapping function that always returns "av" as policy ID to use
         # (for any agent).
-        policy_mapping_fn=lambda agent_id, episode, **kwargs: "av",
+        policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "av",
     )
 )
