Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[RLlib] MultiAgentEnv API enhancements (related to defining obs-/action spaces for agents). #47830

Merged
209 changes: 125 additions & 84 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,12 +982,26 @@ def build_env_to_module_connector(self, env):
f"pipeline)! Your function returned {val_}."
)

obs_space = getattr(env, "single_observation_space", env.observation_space)
if obs_space is None and self.is_multi_agent():
obs_space = gym.spaces.Dict(
{
aid: env.get_observation_space(aid)
for aid in env.unwrapped.possible_agents
}
)
act_space = getattr(env, "single_action_space", env.action_space)
if act_space is None and self.is_multi_agent():
act_space = gym.spaces.Dict(
{
aid: env.get_action_space(aid)
for aid in env.unwrapped.possible_agents
}
)
pipeline = EnvToModulePipeline(
input_observation_space=obs_space,
input_action_space=act_space,
connectors=custom_connectors,
input_observation_space=getattr(
env, "single_observation_space", env.observation_space
),
input_action_space=getattr(env, "single_action_space", env.action_space),
)

if self.add_default_connectors_to_env_to_module_pipeline:
Expand Down Expand Up @@ -1048,12 +1062,26 @@ def build_module_to_env_connector(self, env):
f"pipeline)! Your function returned {val_}."
)

obs_space = getattr(env, "single_observation_space", env.observation_space)
if obs_space is None and self.is_multi_agent():
obs_space = gym.spaces.Dict(
{
aid: env.get_observation_space(aid)
for aid in env.unwrapped.possible_agents
}
)
act_space = getattr(env, "single_action_space", env.action_space)
if act_space is None and self.is_multi_agent():
act_space = gym.spaces.Dict(
{
aid: env.get_action_space(aid)
for aid in env.unwrapped.possible_agents
}
)
pipeline = ModuleToEnvPipeline(
input_observation_space=obs_space,
input_action_space=act_space,
connectors=custom_connectors,
input_observation_space=getattr(
env, "single_observation_space", env.observation_space
),
input_action_space=getattr(env, "single_action_space", env.action_space),
)

if self.add_default_connectors_to_module_to_env_pipeline:
Expand Down Expand Up @@ -4916,47 +4944,54 @@ def get_multi_agent_setup(

# Infer observation space.
if policy_spec.observation_space is None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Module's space is provided -> Use it as-is.
if spaces is not None and pid in spaces:
obs_space = spaces[pid][0]
elif env_obs_space is not None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Multi-agent case AND different agents have different spaces:
# Need to reverse map spaces (for the different agents) to certain
# policy IDs.
if (
isinstance(env_unwrapped, MultiAgentEnv)
and hasattr(env_unwrapped, "_obs_space_in_preferred_format")
and env_unwrapped._obs_space_in_preferred_format
):
obs_space = None
mapping_fn = self.policy_mapping_fn
one_obs_space = next(iter(env_obs_space.values()))
# If all obs spaces are the same anyways, just use the first
# single-agent space.
if all(s == one_obs_space for s in env_obs_space.values()):
obs_space = one_obs_space
# Otherwise, we have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in env_unwrapped.get_agent_ids():
# Match: Assign spaces for this agentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
obs_space is not None
and env_obs_space[aid] != obs_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different observation spaces!"
)
obs_space = env_obs_space[aid]
# Otherwise, just use env's obs space as-is.
# MultiAgentEnv -> Check, whether agents have different spaces.
elif isinstance(env_unwrapped, MultiAgentEnv):
obs_space = None
mapping_fn = self.policy_mapping_fn
aids = list(
env_unwrapped.possible_agents
if hasattr(env_unwrapped, "possible_agents")
and env_unwrapped.possible_agents
else env_unwrapped.get_agent_ids()
)
if len(aids) == 0:
one_obs_space = env_unwrapped.observation_space
else:
obs_space = env_obs_space
one_obs_space = env_unwrapped.get_observation_space(aids[0])
# If all obs spaces are the same, just use the first space.
if all(
env_unwrapped.get_observation_space(aid) == one_obs_space
for aid in aids
):
obs_space = one_obs_space
# Need to reverse-map spaces (for the different agents) to certain
# policy IDs. We have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in aids:
# Match: Assign spaces for this agentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
obs_space is not None
and env_unwrapped.get_observation_space(aid)
!= obs_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different observation spaces!"
)
obs_space = env_unwrapped.get_observation_space(aid)
# Just use env's obs space as-is.
elif env_obs_space is not None:
obs_space = env_obs_space
# Space given directly in config.
elif self.observation_space:
obs_space = self.observation_space
Expand All @@ -4972,47 +5007,53 @@ def get_multi_agent_setup(

# Infer action space.
if policy_spec.action_space is None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Module's space is provided -> Use it as-is.
if spaces is not None and pid in spaces:
act_space = spaces[pid][1]
elif env_act_space is not None:
env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
# Multi-agent case AND different agents have different spaces:
# Need to reverse map spaces (for the different agents) to certain
# policy IDs.
if (
isinstance(env_unwrapped, MultiAgentEnv)
and hasattr(env_unwrapped, "_action_space_in_preferred_format")
and env_unwrapped._action_space_in_preferred_format
):
act_space = None
mapping_fn = self.policy_mapping_fn
one_act_space = next(iter(env_act_space.values()))
# If all action spaces are the same anyways, just use the first
# single-agent space.
if all(s == one_act_space for s in env_act_space.values()):
act_space = one_act_space
# Otherwise, we have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in env_unwrapped.get_agent_ids():
# Match: Assign spaces for this AgentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
act_space is not None
and env_act_space[aid] != act_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different action spaces!"
)
act_space = env_act_space[aid]
# Otherwise, just use env's action space as-is.
# MultiAgentEnv -> Check, whether agents have different spaces.
elif isinstance(env_unwrapped, MultiAgentEnv):
act_space = None
mapping_fn = self.policy_mapping_fn
aids = list(
env_unwrapped.possible_agents
if hasattr(env_unwrapped, "possible_agents")
and env_unwrapped.possible_agents
else env_unwrapped.get_agent_ids()
)
if len(aids) == 0:
one_act_space = env_unwrapped.action_space
else:
act_space = env_act_space
one_act_space = env_unwrapped.get_action_space(aids[0])
                        # If all action spaces are the same, just use the first space.
if all(
env_unwrapped.get_action_space(aid) == one_act_space
for aid in aids
):
act_space = one_act_space
# Need to reverse-map spaces (for the different agents) to certain
# policy IDs. We have to compare the ModuleID with all possible
# AgentIDs and find the agent ID that matches.
elif mapping_fn:
for aid in aids:
# Match: Assign spaces for this AgentID to the PolicyID.
if mapping_fn(aid, None, worker=None) == pid:
# Make sure, different agents that map to the same
# policy don't have different spaces.
if (
act_space is not None
and env_unwrapped.get_action_space(aid) != act_space
):
raise ValueError(
"Two agents in your environment map to the "
"same policyID (as per your `policy_mapping"
"_fn`), however, these agents also have "
"different action spaces!"
)
act_space = env_unwrapped.get_action_space(aid)
# Just use env's action space as-is.
elif env_act_space is not None:
act_space = env_act_space
elif self.action_space:
act_space = self.action_space
else:
Expand Down
28 changes: 14 additions & 14 deletions rllib/algorithms/tests/test_algorithm_rl_module_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_e2e_load_simple_multi_rl_module(self):
for i in range(NUM_AGENTS):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# If we want to use this externally created module in the algorithm,
# we need to provide the same config as the algorithm.
model_config_dict=config.model_config
Expand Down Expand Up @@ -115,8 +115,8 @@ def test_e2e_load_complex_multi_rl_module(self):
for i in range(NUM_AGENTS):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# If we want to use this externally created module in the algorithm,
# we need to provide the same config as the algorithm.
model_config_dict=config.model_config
Expand All @@ -131,8 +131,8 @@ def test_e2e_load_complex_multi_rl_module(self):
# create a RLModule to load and override the "policy_1" module with
module_to_swap_in = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the algorithm
# to be able to use this module later.
model_config_dict=config.model_config | {"fcnet_hiddens": [64]},
Expand All @@ -146,8 +146,8 @@ def test_e2e_load_complex_multi_rl_module(self):
# and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
Expand Down Expand Up @@ -258,8 +258,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
for i in range(num_agents):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the
# algorithm to be able to use this module later.
model_config_dict=config.model_config
Expand All @@ -274,8 +274,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
# create a RLModule to load and override the "policy_1" module with
module_to_swap_in = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the algorithm
# to be able to use this module later.
model_config_dict=config.model_config | {"fcnet_hiddens": [64]},
Expand All @@ -289,8 +289,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
# and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec(
module_class=module_class,
observation_space=env.observation_space[0],
action_space=env.action_space[0],
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
Expand Down
Loading
Loading