From eeb18e50d5b1c7967941186487e3e3d04ef59fec Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Sat, 28 Sep 2024 23:46:52 +0200
Subject: [PATCH] [RLlib] MultiAgentEnv API enhancements (related to defining
 obs-/action spaces for agents). (#47830)

Signed-off-by: ujjawal-khare
---
 .../tests/test_algorithm_rl_module_restore.py |  18 +--
 rllib/core/rl_module/multi_rl_module.py       |  59 ++++++++-
 .../rl_module/tests/test_multi_rl_module.py   | 124 +++++++++++-------
 rllib/env/tests/test_multi_agent_env.py       |  52 --------
 4 files changed, 134 insertions(+), 119 deletions(-)

diff --git a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py
index b9979da368d3..a7b1bf7a7586 100644
--- a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py
+++ b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py
@@ -69,7 +69,7 @@ def test_e2e_load_simple_multi_rl_module(self):
         module_specs = {}
         for i in range(NUM_AGENTS):
             module_specs[f"policy_{i}"] = RLModuleSpec(
-                module_class=PPOTorchRLModule,
+                module_class=module_class,
                 observation_space=env.get_observation_space(0),
                 action_space=env.get_action_space(0),
                 # If we want to use this externally created module in the algorithm,
@@ -110,7 +110,7 @@ def test_e2e_load_complex_multi_rl_module(self):
         module_specs = {}
         for i in range(NUM_AGENTS):
             module_specs[f"policy_{i}"] = RLModuleSpec(
-                module_class=PPOTorchRLModule,
+                module_class=module_class,
                 observation_space=env.get_observation_space(0),
                 action_space=env.get_action_space(0),
                 # If we want to use this externally created module in the algorithm,
@@ -125,7 +125,7 @@ def test_e2e_load_complex_multi_rl_module(self):
 
         # create a RLModule to load and override the "policy_1" module with
         module_to_swap_in = RLModuleSpec(
-            module_class=PPOTorchRLModule,
+            module_class=module_class,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
             # Note, we need to pass in the default model config for the algorithm
@@ -140,10 +140,10 @@ def test_e2e_load_complex_multi_rl_module(self):
         # create a new MARL_spec with the checkpoint from the marl_checkpoint
         # and the module_to_swap_in_checkpoint
         module_specs["policy_1"] = RLModuleSpec(
-            module_class=PPOTorchRLModule,
+            module_class=module_class,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
-            model_config=DefaultModelConfig(fcnet_hiddens=[64]),
+            model_config_dict={"fcnet_hiddens": [64]},
             catalog_class=PPOCatalog,
             load_state_path=module_to_swap_in_path,
         )
@@ -250,7 +250,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
         module_specs = {}
         for i in range(num_agents):
             module_specs[f"policy_{i}"] = RLModuleSpec(
-                module_class=PPOTorchRLModule,
+                module_class=module_class,
                 observation_space=env.get_observation_space(0),
                 action_space=env.get_action_space(0),
                 # Note, we need to pass in the default model config for the
@@ -265,7 +265,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
 
         # create a RLModule to load and override the "policy_1" module with
         module_to_swap_in = RLModuleSpec(
-            module_class=PPOTorchRLModule,
+            module_class=module_class,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
             # Note, we need to pass in the default model config for the algorithm
@@ -280,10 +280,10 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
         # create a new MARL_spec with the checkpoint from the marl_checkpoint
         # and the module_to_swap_in_checkpoint
module_specs["policy_1"] = RLModuleSpec( - module_class=PPOTorchRLModule, + module_class=module_class, observation_space=env.get_observation_space(0), action_space=env.get_action_space(0), - model_config=DefaultModelConfig(fcnet_hiddens=[64]), + model_config_dict={"fcnet_hiddens": [64]}, catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, ) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index b140d9d13a96..0febfcf5ec15 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -18,11 +18,10 @@ ValuesView, ) -import gymnasium as gym - from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import force_list from ray.rllib.utils.annotations import ( override, @@ -405,8 +404,54 @@ def __len__(self) -> int: """Returns the number of RLModules within this MultiRLModule.""" return len(self._rl_modules) - def __repr__(self) -> str: - return f"MARL({pprint.pformat(self._rl_modules)})" + The underlying single-agent RLModules will check the input specs. + """ + return [] + + @override(RLModule) + def _forward_train( + self, batch: MultiAgentBatch, **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Runs the forward_train pass. + + Args: + batch: The batch of multi-agent data (i.e. mapping from module ids to + individual modules' batches). + + Returns: + The output of the forward_train pass the specified modules. + """ + return self._run_forward_pass("forward_train", batch, **kwargs) + + @override(RLModule) + def _forward_inference( + self, batch: MultiAgentBatch, **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Runs the forward_inference pass. + + Args: + batch: The batch of multi-agent data (i.e. mapping from module ids to + individual modules' batches). + + Returns: + The output of the forward_inference pass the specified modules. + """ + return self._run_forward_pass("forward_inference", batch, **kwargs) + + @override(RLModule) + def _forward_exploration( + self, batch: MultiAgentBatch, **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Runs the forward_exploration pass. + + Args: + batch: The batch of multi-agent data (i.e. mapping from module ids to + individual modules' batches). + + Returns: + The output of the forward_exploration pass the specified modules. + """ + return self._run_forward_pass("forward_exploration", batch, **kwargs) @override(RLModule) def get_state( @@ -461,12 +506,12 @@ def set_state(self, state: StateDict) -> None: ) # Go through all of our current modules and check, whether they are listed # in the given MultiRLModuleSpec. If not, erase them from `self`. - for module_id, module in self._rl_modules.copy().items(): - if module_id not in multi_rl_module_spec.rl_module_specs: + for module_id, module in self._rl_modules.items(): + if module_id not in multi_rl_module_spec.module_specs: self.remove_module(module_id, raise_err_if_not_found=True) # Go through all the modules in the given MultiRLModuleSpec and if # they are not present in `self`, add them. 
-        for module_id, module_spec in multi_rl_module_spec.rl_module_specs.items():
+        for module_id, module_spec in multi_rl_module_spec.module_specs.items():
             if module_id not in self:
                 self.add_module(module_id, module_spec.build(), override=False)
diff --git a/rllib/core/rl_module/tests/test_multi_rl_module.py b/rllib/core/rl_module/tests/test_multi_rl_module.py
index 800d36061f28..02aeab9d8901 100644
--- a/rllib/core/rl_module/tests/test_multi_rl_module.py
+++ b/rllib/core/rl_module/tests/test_multi_rl_module.py
@@ -2,9 +2,9 @@
 import unittest
 
 from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC, DEFAULT_MODULE_ID
-from ray.rllib.core.rl_module.rl_module import RLModuleSpec
-from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
-from ray.rllib.examples.rl_modules.classes.vpg_rlm import VPGTorchRLModule
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModuleConfig
+from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig
+from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
 from ray.rllib.env.multi_agent_env import make_multi_agent
 from ray.rllib.utils.test_utils import check
 
@@ -15,17 +15,17 @@ def test_from_config(self):
         env_class = make_multi_agent("CartPole-v0")
         env = env_class({"num_agents": 2})
         module1 = RLModuleSpec(
-            module_class=VPGTorchRLModule,
+            module_class=DiscreteBCTorchModule,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
-            model_config={"hidden_dim": 32},
+            model_config_dict={"fcnet_hiddens": [32]},
         )
 
         module2 = RLModuleSpec(
-            module_class=VPGTorchRLModule,
+            module_class=DiscreteBCTorchModule,
             observation_space=env.get_observation_space(0),
             action_space=env.get_action_space(0),
-            model_config={"hidden_dim": 32},
+            model_config_dict={"fcnet_hiddens": [32]},
         )
 
         multi_rl_module = MultiRLModule(
@@ -41,10 +41,12 @@ def test_as_multi_rl_module(self):
         env_class = make_multi_agent("CartPole-v0")
         env = env_class({"num_agents": 2})
 
-        multi_rl_module = VPGTorchRLModule(
-            observation_space=env.get_observation_space(0),
-            action_space=env.get_action_space(0),
-            model_config={"hidden_dim": 32},
+        multi_rl_module = DiscreteBCTorchModule(
+            config=RLModuleConfig(
+                env.get_observation_space(0),
+                env.get_action_space(0),
+                model_config_dict={"fcnet_hiddens": [32]},
+            )
         ).as_multi_rl_module()
 
-        self.assertNotIsInstance(multi_rl_module, VPGTorchRLModule)
+        self.assertNotIsInstance(multi_rl_module, DiscreteBCTorchModule)
@@ -60,10 +62,12 @@ def test_get_state_and_set_state(self):
         env_class = make_multi_agent("CartPole-v0")
         env = env_class({"num_agents": 2})
 
-        module = VPGTorchRLModule(
-            observation_space=env.get_observation_space(0),
-            action_space=env.get_action_space(0),
-            model_config={"hidden_dim": 32},
+        module = DiscreteBCTorchModule(
+            config=RLModuleConfig(
+                env.get_observation_space(0),
+                env.get_action_space(0),
+                model_config_dict={"fcnet_hiddens": [32]},
+            )
         ).as_multi_rl_module()
 
         state = module.get_state()
@@ -77,10 +81,12 @@ def test_get_state_and_set_state(self):
             set(module[DEFAULT_MODULE_ID].get_state().keys()),
         )
 
-        module2 = VPGTorchRLModule(
-            observation_space=env.get_observation_space(0),
-            action_space=env.get_action_space(0),
-            model_config={"hidden_dim": 32},
+        module2 = DiscreteBCTorchModule(
+            config=RLModuleConfig(
+                env.get_observation_space(0),
+                env.get_action_space(0),
+                model_config_dict={"fcnet_hiddens": [32]},
+            )
         ).as_multi_rl_module()
         state2 = module2.get_state()
         check(state[DEFAULT_MODULE_ID], state2[DEFAULT_MODULE_ID], false=True)
@@ -95,18 +101,22 @@ def test_add_remove_modules(self):
         env_class = make_multi_agent("CartPole-v0")
make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - module = VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 32}, + module = DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [32]}, + ) ).as_multi_rl_module() module.add_module( "test", - VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 32}, + DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [32]}, + ) ), ) self.assertEqual(set(module.keys()), {DEFAULT_MODULE_ID, "test"}) @@ -118,20 +128,24 @@ def test_add_remove_modules(self): ValueError, lambda: module.add_module( DEFAULT_MODULE_ID, - VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 32}, + DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [32]}, + ) ), ), ) module.add_module( DEFAULT_MODULE_ID, - VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 32}, + DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [32]}, + ) ), override=True, ) @@ -140,26 +154,32 @@ def test_save_to_path_and_from_checkpoint(self): """Test saving and loading from checkpoint after adding / removing modules.""" env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - module = VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 32}, + module = DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [32]}, + ) ).as_multi_rl_module() module.add_module( "test", - VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 32}, + DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [32]}, + ) ), ) module.add_module( "test2", - VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 128}, + DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [128]}, + ) ), ) @@ -185,10 +205,12 @@ def test_save_to_path_and_from_checkpoint(self): # Check that - after adding a new module - the checkpoint is correct. module.add_module( "test3", - VPGTorchRLModule( - observation_space=env.get_observation_space(0), - action_space=env.get_action_space(0), - model_config={"hidden_dim": 120}, + DiscreteBCTorchModule( + config=RLModuleConfig( + env.get_observation_space(0), + env.get_action_space(0), + model_config_dict={"fcnet_hiddens": [120]}, + ) ), ) # Check that - after adding a module - the checkpoint is correct. 
diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py index 080713764f04..9febd9cc05d6 100644 --- a/rllib/env/tests/test_multi_agent_env.py +++ b/rllib/env/tests/test_multi_agent_env.py @@ -810,58 +810,6 @@ def is_recurrent(self): check(batch["state_in_0"][i], h) check(batch["state_out_0"][i], h) - def test_space_in_preferred_format(self): - env = NestedMultiAgentEnv() - action_space_in_preferred_format = ( - env._check_if_action_space_maps_agent_id_to_sub_space() - ) - obs_space_in_preferred_format = ( - env._check_if_obs_space_maps_agent_id_to_sub_space() - ) - assert action_space_in_preferred_format, "Act space is not in preferred format." - assert obs_space_in_preferred_format, "Obs space is not in preferred format." - - env2 = make_multi_agent("CartPole-v1")() - action_spaces_in_preferred_format = ( - env2._check_if_action_space_maps_agent_id_to_sub_space() - ) - obs_space_in_preferred_format = ( - env2._check_if_obs_space_maps_agent_id_to_sub_space() - ) - assert ( - action_spaces_in_preferred_format - ), "Action space should be in preferred format but isn't." - assert ( - obs_space_in_preferred_format - ), "Observation space should be in preferred format but isn't." - - def test_spaces_sample_contain_in_preferred_format(self): - env = NestedMultiAgentEnv() - # this environment has spaces that are in the preferred format - # for multi-agent environments where the spaces are dict spaces - # mapping agent-ids to sub-spaces - obs = env.observation_space_sample() - assert env.observation_space_contains( - obs - ), "Observation space does not contain obs" - - action = env.action_space_sample() - assert env.action_space_contains(action), "Action space does not contain action" - - def test_spaces_sample_contain_not_in_preferred_format(self): - env = make_multi_agent("CartPole-v1")({"num_agents": 2}) - # this environment has spaces that are not in the preferred format - # for multi-agent environments where the spaces not in the preferred - # format, users must override the observation_space_contains, - # action_space_contains observation_space_sample, - # and action_space_sample methods in order to do proper checks - obs = env.observation_space_sample() - assert env.observation_space_contains( - obs - ), "Observation space does not contain obs" - action = env.action_space_sample() - assert env.action_space_contains(action), "Action space does not contain action" - if __name__ == "__main__": import pytest
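
Note: all three `_forward_*` methods added to rllib/core/rl_module/multi_rl_module.py
above delegate to `_run_forward_pass`, whose implementation lies outside this diff.
The sketch below shows only the dispatch pattern they rely on; the function name,
signature, and plain-dict batch type are illustrative assumptions, not RLlib's
actual code:

    from typing import Any, Dict

    def run_forward_pass(
        rl_modules: Dict[str, Any],
        forward_fn_name: str,
        batch: Dict[str, Dict[str, Any]],
        **kwargs,
    ) -> Dict[str, Dict[str, Any]]:
        """Hypothetical stand-in for MultiRLModule._run_forward_pass."""
        outputs = {}
        for module_id, module_batch in batch.items():
            # Look up the named forward method (e.g. "forward_train") on the
            # sub-module and run it on that module's slice of the batch only.
            # Modules without data in this batch produce no output.
            forward_fn = getattr(rl_modules[module_id], forward_fn_name)
            outputs[module_id] = forward_fn(module_batch, **kwargs)
        return outputs

Under this assumption, the per-module outputs keep the same ModuleID keys as the
incoming multi-agent batch, which is why the `_forward_*` wrappers can return a
`Dict[ModuleID, Dict[str, Any]]` without any further merging.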