Skip to content

Commit

Permalink
[RLlib] MultiAgentEnv API enhancements (related to defining obs-/acti…
Browse files Browse the repository at this point in the history
…on spaces for agents). (ray-project#47830)

Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 9ac5d16 commit 993452c
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 119 deletions.
18 changes: 9 additions & 9 deletions rllib/algorithms/tests/test_algorithm_rl_module_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_e2e_load_simple_multi_rl_module(self):
module_specs = {}
for i in range(NUM_AGENTS):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# If we want to use this externally created module in the algorithm,
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_e2e_load_complex_multi_rl_module(self):
module_specs = {}
for i in range(NUM_AGENTS):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# If we want to use this externally created module in the algorithm,
Expand All @@ -125,7 +125,7 @@ def test_e2e_load_complex_multi_rl_module(self):

# create a RLModule to load and override the "policy_1" module with
module_to_swap_in = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the algorithm
Expand All @@ -140,10 +140,10 @@ def test_e2e_load_complex_multi_rl_module(self):
# create a new MARL_spec with the checkpoint from the marl_checkpoint
# and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config=DefaultModelConfig(fcnet_hiddens=[64]),
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
)
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
module_specs = {}
for i in range(num_agents):
module_specs[f"policy_{i}"] = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the
Expand All @@ -265,7 +265,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):

# create a RLModule to load and override the "policy_1" module with
module_to_swap_in = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
# Note, we need to pass in the default model config for the algorithm
Expand All @@ -280,10 +280,10 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
# create a new MARL_spec with the checkpoint from the marl_checkpoint
# and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec(
module_class=PPOTorchRLModule,
module_class=module_class,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config=DefaultModelConfig(fcnet_hiddens=[64]),
model_config_dict={"fcnet_hiddens": [64]},
catalog_class=PPOCatalog,
load_state_path=module_to_swap_in_path,
)
Expand Down
59 changes: 52 additions & 7 deletions rllib/core/rl_module/multi_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
ValuesView,
)

import gymnasium as gym

from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import (
override,
Expand Down Expand Up @@ -405,8 +404,54 @@ def __len__(self) -> int:
"""Returns the number of RLModules within this MultiRLModule."""
return len(self._rl_modules)

def __repr__(self) -> str:
return f"MARL({pprint.pformat(self._rl_modules)})"
The underlying single-agent RLModules will check the input specs.
"""
return []

@override(RLModule)
def _forward_train(
self, batch: MultiAgentBatch, **kwargs
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_train pass.
Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).
Returns:
The output of the forward_train pass the specified modules.
"""
return self._run_forward_pass("forward_train", batch, **kwargs)

@override(RLModule)
def _forward_inference(
self, batch: MultiAgentBatch, **kwargs
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_inference pass.
Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).
Returns:
The output of the forward_inference pass the specified modules.
"""
return self._run_forward_pass("forward_inference", batch, **kwargs)

@override(RLModule)
def _forward_exploration(
self, batch: MultiAgentBatch, **kwargs
) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]:
"""Runs the forward_exploration pass.
Args:
batch: The batch of multi-agent data (i.e. mapping from module ids to
individual modules' batches).
Returns:
The output of the forward_exploration pass the specified modules.
"""
return self._run_forward_pass("forward_exploration", batch, **kwargs)

@override(RLModule)
def get_state(
Expand Down Expand Up @@ -461,12 +506,12 @@ def set_state(self, state: StateDict) -> None:
)
# Go through all of our current modules and check, whether they are listed
# in the given MultiRLModuleSpec. If not, erase them from `self`.
for module_id, module in self._rl_modules.copy().items():
if module_id not in multi_rl_module_spec.rl_module_specs:
for module_id, module in self._rl_modules.items():
if module_id not in multi_rl_module_spec.module_specs:
self.remove_module(module_id, raise_err_if_not_found=True)
# Go through all the modules in the given MultiRLModuleSpec and if
# they are not present in `self`, add them.
for module_id, module_spec in multi_rl_module_spec.rl_module_specs.items():
for module_id, module_spec in multi_rl_module_spec.module_specs.items():
if module_id not in self:
self.add_module(module_id, module_spec.build(), override=False)

Expand Down
124 changes: 73 additions & 51 deletions rllib/core/rl_module/tests/test_multi_rl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import unittest

from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
from ray.rllib.examples.rl_modules.classes.vpg_rlm import VPGTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModuleConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.utils.test_utils import check

Expand All @@ -15,17 +15,17 @@ def test_from_config(self):
env_class = make_multi_agent("CartPole-v0")
env = env_class({"num_agents": 2})
module1 = RLModuleSpec(
module_class=VPGTorchRLModule,
module_class=DiscreteBCTorchModule,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
model_config_dict={"fcnet_hiddens": [32]},
)

module2 = RLModuleSpec(
module_class=VPGTorchRLModule,
module_class=DiscreteBCTorchModule,
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
model_config_dict={"fcnet_hiddens": [32]},
)

multi_rl_module = MultiRLModule(
Expand All @@ -41,10 +41,12 @@ def test_as_multi_rl_module(self):
env_class = make_multi_agent("CartPole-v0")
env = env_class({"num_agents": 2})

multi_rl_module = VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
multi_rl_module = DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
).as_multi_rl_module()

self.assertNotIsInstance(multi_rl_module, VPGTorchRLModule)
Expand All @@ -60,10 +62,12 @@ def test_get_state_and_set_state(self):
env_class = make_multi_agent("CartPole-v0")
env = env_class({"num_agents": 2})

module = VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
module = DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
).as_multi_rl_module()

state = module.get_state()
Expand All @@ -77,10 +81,12 @@ def test_get_state_and_set_state(self):
set(module[DEFAULT_MODULE_ID].get_state().keys()),
)

module2 = VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
module2 = DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
).as_multi_rl_module()
state2 = module2.get_state()
check(state[DEFAULT_MODULE_ID], state2[DEFAULT_MODULE_ID], false=True)
Expand All @@ -95,18 +101,22 @@ def test_add_remove_modules(self):

env_class = make_multi_agent("CartPole-v0")
env = env_class({"num_agents": 2})
module = VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
module = DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
).as_multi_rl_module()

module.add_module(
"test",
VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
),
)
self.assertEqual(set(module.keys()), {DEFAULT_MODULE_ID, "test"})
Expand All @@ -118,20 +128,24 @@ def test_add_remove_modules(self):
ValueError,
lambda: module.add_module(
DEFAULT_MODULE_ID,
VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
),
),
)

module.add_module(
DEFAULT_MODULE_ID,
VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
),
override=True,
)
Expand All @@ -140,26 +154,32 @@ def test_save_to_path_and_from_checkpoint(self):
"""Test saving and loading from checkpoint after adding / removing modules."""
env_class = make_multi_agent("CartPole-v0")
env = env_class({"num_agents": 2})
module = VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
module = DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
).as_multi_rl_module()

module.add_module(
"test",
VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 32},
DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [32]},
)
),
)
module.add_module(
"test2",
VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 128},
DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [128]},
)
),
)

Expand All @@ -185,10 +205,12 @@ def test_save_to_path_and_from_checkpoint(self):
# Check that - after adding a new module - the checkpoint is correct.
module.add_module(
"test3",
VPGTorchRLModule(
observation_space=env.get_observation_space(0),
action_space=env.get_action_space(0),
model_config={"hidden_dim": 120},
DiscreteBCTorchModule(
config=RLModuleConfig(
env.get_observation_space(0),
env.get_action_space(0),
model_config_dict={"fcnet_hiddens": [120]},
)
),
)
# Check that - after adding a module - the checkpoint is correct.
Expand Down
Loading

0 comments on commit 993452c

Please sign in to comment.