diff --git a/examples/sarl/ppo_gym.py b/examples/sarl/ppo_gym.py index 0a734e9..ee37ec7 100644 --- a/examples/sarl/ppo_gym.py +++ b/examples/sarl/ppo_gym.py @@ -4,10 +4,11 @@ from argparse import ArgumentParser from malib.learner import IndependentAgent -from malib.scenarios.marl_scenario import MARLScenario - -from malib.runner import run +from malib.scenarios import sarl_scenario +from malib.rl.config import Algorithm from malib.rl.ppo import PPOPolicy, PPOTrainer, DEFAULT_CONFIG +from malib.learner.config import LearnerConfig +from malib.rollout.config import RolloutConfig from malib.rollout.envs.gym import env_desc_gen @@ -23,59 +24,38 @@ trainer_config["total_timesteps"] = int(1e6) trainer_config["use_cuda"] = args.use_cuda - training_config = { - "learner_type": IndependentAgent, - "trainer_config": trainer_config, - "custom_config": {}, - } - - rollout_config = { - "fragment_length": 2000, # determine the size of sended data block - "max_step": 200, - "num_eval_episodes": 10, - "num_threads": 2, - "num_env_per_thread": 10, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "time_step", - "postprocessor_types": ["defaults"], - # every # rollout epoch run evaluation. - "eval_interval": 1, - "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` - } - - # one to one, no sharing, if sharing, implemented as: - # agent_mapping_func = lambda agent: "default" - agent_mapping_func = lambda agent: agent - - algorithms = { - "default": ( - PPOPolicy, - PPOTrainer, - # model configuration, None as default - {}, - {"use_cuda": args.use_cuda}, - ) - } - - env_description = env_desc_gen(env_id=args.env_id, scenario_configs={}) - runtime_logdir = os.path.join(args.log_dir, f"sa_ppo_gym/{time.time()}") + runtime_logdir = os.path.join( + args.log_dir, f"gym/{args.env_id}/independent_ppo/{time.time()}" + ) if not os.path.exists(runtime_logdir): os.makedirs(runtime_logdir) - scenario = MARLScenario( + scenario = sarl_scenario.SARLScenario( name=f"ppo-gym-{args.env_id}", log_dir=runtime_logdir, - algorithms=algorithms, - env_description=env_description, - training_config=training_config, - rollout_config=rollout_config, - agent_mapping_func=agent_mapping_func, + env_desc=env_desc_gen(env_id=args.env_id), + algorithm=Algorithm( + trainer=PPOTrainer, + policy=PPOPolicy, + model_config=None, # use default + trainer_config=trainer_config, + ), + learner_config=LearnerConfig( + learner_type=IndependentAgent, + feature_handler_meta_gen=None, + custom_config={}, + ), + rollout_config=RolloutConfig( + num_workers=1, + ), + agent_mapping_func=lambda agent: agent, stopping_conditions={ "training": {"max_iteration": int(1e10)}, "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, }, ) - run(scenario) + results = sarl_scenario.execution_plan( + experiment_tag=scenario.name, scenario=scenario, verbose=True + ) diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index a62a217..92000c4 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -21,19 +21,24 @@ def __init__( self, grpc_thread_num_workers: int, max_message_length: int, - feature_handler_cls: Type[BaseFeature], + feature_handler: BaseFeature = None, + feature_handler_cls: Type[BaseFeature] = None, **feature_handler_kwargs, ) -> None: super().__init__() # start a service as thread - self.feature_handler: BaseFeature = feature_handler_cls( + self.feature_handler: BaseFeature = 
feature_handler or feature_handler_cls( **feature_handler_kwargs ) + self.grpc_thread_num_workers = grpc_thread_num_workers + self.max_message_length = max_message_length + + def start_server(self): self.server_port = find_free_port() self.server = service_wrapper( - grpc_thread_num_workers, - max_message_length, + self.grpc_thread_num_workers, + self.max_message_length, self.server_port, )(self.feature_handler) self.server.start() diff --git a/malib/backend/offline_dataset_server.py b/malib/backend/offline_dataset_server.py deleted file mode 100644 index b99c9c8..0000000 --- a/malib/backend/offline_dataset_server.py +++ /dev/null @@ -1,193 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import Dict, Any, Tuple, Union, List -from concurrent.futures import ThreadPoolExecutor -from readerwriterlock import rwlock - -import traceback -import time - -import numpy as np -import ray - -from ray.util.queue import Queue - -from malib.remote.interface import RemoteInterface -from malib.utils.logging import Logger -from malib.utils.tianshou_batch import Batch -from malib.utils.replay_buffer import ReplayBuffer, MultiagentReplayBuffer - - -def write_table( - marker: rwlock.RWLockFair, - buffer: Union[MultiagentReplayBuffer, ReplayBuffer], - writer: Queue, -): - wlock = marker.gen_wlock() - while True: - try: - batches: Union[Batch, List[Batch]] = writer.get() - with wlock: - if not isinstance(batches, List): - batches = [batches] - for e in batches: - buffer.add_batch(e) - except Exception as e: - print(traceback.format_exc()) - break - - -def read_table( - marker: rwlock.RWLockFair, - buffer: Union[MultiagentReplayBuffer, ReplayBuffer], - batch_size: int, - reader: Queue, -): - rlock = marker.gen_rlock() - while True: - try: - with rlock: - if len(buffer) >= batch_size: - ret = buffer.sample(batch_size) - # batch, indices = buffer.sample(batch_size) - else: - # batch, indices = [], np.array([], int) - if isinstance(buffer, MultiagentReplayBuffer): - ret = {} - else: - ret = ([], np.array([], int)) - reader.put_nowait(ret) - except Exception as e: - print(traceback.format_exc()) - break - - -class OfflineDataset(RemoteInterface): - def __init__(self, table_capacity: int, max_consumer_size: int = 1024) -> None: - """Construct an offline datataset. It maintans a dict of datatable, each for a training instance. 
- - Args: - table_capacity (int): Table capacity, it indicates the buffer size of each data table. - max_consumer_size (int, optional): Defines the maximum of concurrency. Defaults to 1024. - """ - - self.tb_capacity = table_capacity - self.reader_queues: Dict[str, Queue] = {} - self.writer_queues: Dict[str, Queue] = {} - self.buffers: Dict[str, ReplayBuffer] = {} - self.markers: Dict[str, rwlock.RWLockFair] = {} - self.thread_pool = ThreadPoolExecutor(max_workers=max_consumer_size) - - def start(self): - Logger.info("Dataset server started") - - def start_producer_pipe( - self, - name: str, - stack_num: int = 1, - ignore_obs_next: bool = False, - save_only_last_obs: bool = False, - sample_avail: bool = False, - **kwargs, - ) -> Tuple[str, Queue]: - """Start a producer pipeline and create a datatable if not exisits. - - Args: - name (str): The name of datatable need to access - stack_num (int, optional): Indicates how many steps are stacked in a single data sample. Defaults to 1. - ignore_obs_next (bool, optional): Ignore the next observation or not. Defaults to False. - save_only_last_obs (bool, optional): Either save only the last observation frame. Defaults to False. - sample_avail (bool, optional): Sample action maks or not. Defaults to False. - - Returns: - Tuple[str, Queue]: A tuple of table name and queue for insert samples. - """ - - if name not in self.buffers: - buffer = ReplayBuffer( - size=self.tb_capacity, - stack_num=stack_num, - ignore_obs_next=ignore_obs_next, - save_only_last_obs=save_only_last_obs, - sample_avail=sample_avail, - **kwargs, - ) - marker = rwlock.RWLockFair() - - self.buffers[name] = buffer - self.markers[name] = marker - - if name not in self.writer_queues: - writer = Queue(actor_options={"num_cpus": 0}) - self.writer_queues[name] = writer - self.thread_pool.submit( - write_table, self.markers[name], self.buffers[name], writer - ) - - return name, self.writer_queues[name] - - def end_producer_pipe(self, name: str): - """Kill a producer pipe with given name. - - Args: - name (str): The name of related data table. - """ - - if name in self.writer_queues: - queue = self.writer_queues.pop(name) - queue.shutdown() - - def start_consumer_pipe(self, name: str, batch_size: int) -> Tuple[str, Queue]: - """Start a consumer pipeline, if there is no such a table that named as `name`, the function will be stucked until the table has been created. - - Args: - name (str): Name of datatable. - batch_size (int): Batch size. - - Returns: - Tuple[str, Queue]: A tuple of table name and queue for retrieving samples. - """ - - queue_id = f"{name}_{time.time()}" - queue = Queue(actor_options={"num_cpus": 0}) - self.reader_queues[queue_id] = queue - # make sure that the buffer is ready - while name not in self.buffers: - time.sleep(1) - self.thread_pool.submit( - read_table, self.markers[name], self.buffers[name], batch_size, queue - ) - return queue_id, queue - - def end_consumer_pipe(self, name: str): - """Kill a consumer pipeline with given table name. - - Args: - name (str): Name of related datatable. 
- """ - - if name in self.reader_queues: - queue = self.reader_queues.pop(name) - queue.shutdown() diff --git a/malib/backend/parameter_server.py b/malib/backend/parameter_server.py deleted file mode 100644 index 3f93e91..0000000 --- a/malib/backend/parameter_server.py +++ /dev/null @@ -1,158 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Author: Ming Zhou - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from argparse import Namespace -from typing import Dict, Any, Sequence -from threading import Lock - -import itertools -import torch - -from malib.rl.common.policy import Policy -from malib.common.strategy_spec import StrategySpec -from malib.remote.interface import RemoteInterface -from malib.utils.logging import Logger - - -class Table: - def __init__(self, policy_meta_data: Dict[str, Any]): - policy_cls = policy_meta_data["policy_cls"] - optim_config = policy_meta_data.get("optim_config") - policy_init_kwargs = Namespace(**policy_meta_data["kwargs"]) - self.state_dict = None - if optim_config is not None: - self.policy: Policy = policy_cls( - observation_space=policy_init_kwargs.observation_space, - action_space=policy_init_kwargs.action_space, - model_config=policy_init_kwargs.model_config, - custom_config=policy_init_kwargs.custom_config, - **policy_init_kwargs.kwargs, - ) - parameters = [list(v) for v in self.policy.parameters().values()] - parameters = itertools.chain(*parameters) - self.optimizer: torch.optim.Optimizer = getattr( - torch.optim, optim_config["type"] - )(parameters, lr=optim_config["lr"]) - else: - self.optimizer: torch.optim.Optimizer = None - self.lock = Lock() - - def set_weights(self, state_dict: Dict[str, Any]): - """Update weights with given weights. - - Args: - state_dict (Dict[str, Any]): A dict of weights - """ - - with self.lock: - self.state_dict = state_dict - - def apply_gradients(self, *gradients): - raise NotImplementedError - - def get_weights(self) -> Dict[str, Any]: - """Retrive model weights. - - Returns: - Dict[str, Any]: Weights dict - """ - - with self.lock: - return self.state_dict - - -class ParameterServer(RemoteInterface): - def __init__(self, **kwargs): - self.tables: Dict[str, Table] = {} - self.lock = Lock() - - def start(self): - """For debug""" - Logger.info("Parameter server started") - - def apply_gradients(self, table_name: str, gradients: Sequence[Any]): - """Apply gradients to a data table. - - Args: - table_name (str): The specified table name. 
-            gradients (Sequence[Any]): Given gradients to update parameters.
-
-        Raises:
-            NotImplementedError: Not implemented yet.
-        """
-
-        raise NotImplementedError
-
-    def get_weights(self, spec_id: str, spec_policy_id: str) -> Dict[str, Any]:
-        """Request for weight retrive, return a dict includes keys: `spec_id`, `spec_policy_id` and `weights`.
-
-        Args:
-            spec_id (str): Strategy spec id.
-            spec_policy_id (str): Related policy id.
-
-        Returns:
-            Dict[str, Any]: A dict.
-        """
-
-        table_name = f"{spec_id}/{spec_policy_id}"
-        weights = self.tables[table_name].get_weights()
-        return {
-            "spec_id": spec_id,
-            "spec_policy_id": spec_policy_id,
-            "weights": weights,
-        }
-
-    def set_weights(
-        self, spec_id: str, spec_policy_id: str, state_dict: Dict[str, Any]
-    ):
-        """Set weights to a parameter table. The table name will be defined as `{spec_id}/{spec_policy_id}`
-
-        Args:
-            spec_id (str): StrategySpec id.
-            spec_policy_id (str): Policy id in the specified strategy spec.
-            state_dict (Dict[str, Any]): A dict that specify the parameters.
-        """
-
-        table_name = f"{spec_id}/{spec_policy_id}"
-        self.tables[table_name].set_weights(state_dict)
-
-    def create_table(self, strategy_spec: StrategySpec) -> str:
-        """Create parameter table with given strategy spec. This function will traverse existing policy \
-            id in this spec, then generate table for policy ids which have no cooresponding tables.
-
-        Args:
-            strategy_spec (StrategySpec): A startegy spec instance.
-
-        Returns:
-            str: Table name formatted as `{startegy_spec_id}/{policy_id}`.
-        """
-
-        with self.lock:
-            for policy_id in strategy_spec.policy_ids:
-                table_name = f"{strategy_spec.id}/{policy_id}"
-                if table_name in self.tables:
-                    continue
-                meta_data = strategy_spec.get_meta_data().copy()
-                self.tables[table_name] = Table(meta_data)
-        return table_name
diff --git a/malib/learner/config.py b/malib/learner/config.py
index ff36d10..afa3136 100644
--- a/malib/learner/config.py
+++ b/malib/learner/config.py
@@ -1,22 +1,23 @@
-from typing import Dict, Any, Union, Type
+from typing import Dict, Any, Union, Type, Callable
 
 from dataclasses import dataclass, field
 
 from malib.learner.learner import Learner
+from malib.backend.dataset_server.feature import BaseFeature
 
 
 # TODO(ming): rename it as LearnerConfig
 @dataclass
-class TrainingConfig:
-    trainer_config: Dict[str, Any]
+class LearnerConfig:
     learner_type: Type[Learner]
+    feature_handler_meta_gen: Callable[["EnvDesc", str], Callable[[str], BaseFeature]]
     custom_config: Dict[str, Any] = field(default_factory=dict())
 
     @classmethod
     def from_raw(
-        cls, config: Union["TrainingConfig", Dict[str, Any]]
-    ) -> "TrainingConfig":
-        """Cat dict-style configuration to TrainingConfig instance
+        cls, config: Union["LearnerConfig", Dict[str, Any]]
+    ) -> "LearnerConfig":
+        """Cast dict-style configuration to LearnerConfig instance
 
         Args:
             config (Dict[str, Any]): A dict
 
@@ -25,7 +26,7 @@ def from_raw(
             RuntimeError: Unexpected config type
 
         Returns:
-            TrainingConfig: A training config instance
+            LearnerConfig: A learner config instance
         """
 
         if isinstance(config, Dict):
diff --git a/malib/learner/learner.py b/malib/learner/learner.py
index 76d5b10..4beab20 100644
--- a/malib/learner/learner.py
+++ b/malib/learner/learner.py
@@ -23,10 +23,10 @@
 # SOFTWARE.
-from typing import Dict, Any, Tuple, Callable, Type, List, Union +from typing import Dict, Any, Tuple, Callable, List, Union, Type from abc import ABC, abstractmethod -from collections import deque +import time import traceback import torch @@ -44,6 +44,7 @@ from malib.common.task import OptimizationTask from malib.common.strategy_spec import StrategySpec from malib.backend.dataset_server.data_loader import DynamicDataset +from malib.backend.dataset_server.feature import BaseFeature from malib.rl.common.trainer import Trainer from malib.rl.common.policy import Policy from malib.rl.config import Algorithm @@ -61,11 +62,10 @@ def __init__( algorithm: Algorithm, agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], - trainer_config: Dict[str, Any], custom_config: Dict[str, Any] = None, - local_buffer_config: Dict = None, - verbose: bool = True, dataset: DynamicDataset = None, + feature_handler_gen: Callable[[str], BaseFeature] = None, + verbose: bool = True, ): """Construct agent interface for training. @@ -80,14 +80,14 @@ def __init__( Note that it should be a subset of the original set of environment agents. trainer_config (Dict[str, Any]): Trainer configuration. custom_config (Dict[str, Any], optional): A dict of custom configuration. Defaults to None. - local_buffer_config (Dict, optional): A dict for local buffer configuration. Defaults to None. + dataset (DynamicDataset, optional): A dataset instance. Defaults to None. + feature_handler_gen (Callable[[str], BaseFeature], optional): A function that generates feature handler. Defaults to None. verbose (bool, True): Enable logging or not. Defaults to True. """ if verbose: Logger.info("\tAssigned GPUs: {}".format(ray.get_gpu_ids())) - local_buffer_config = local_buffer_config or {} device = torch.device("cuda" if ray.get_gpu_ids() else "cpu") # initialize a strategy spec for policy maintainance. @@ -110,27 +110,31 @@ def __init__( self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir) # load policy for trainer - self._trainer: Trainer = algorithm.trainer(trainer_config, self._policy) + self._trainer: Trainer = algorithm.trainer( + algorithm.trainer_config, self._policy + ) - dataset = dataset or self.create_dataset() - self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"]) + if dataset is None: + dataset = DynamicDataset( + grpc_thread_num_workers=2, + max_message_length=1024, + feature_handler=feature_handler_gen(device), + ) + else: + if feature_handler_gen is not None: + # XXX(ming): should we replace feature handler ? 
+ dataset.feature_handler = feature_handler_gen(device) + + dataset.start_server() + + self._data_loader = DataLoader( + dataset, batch_size=algorithm.trainer_config["batch_size"] + ) self._total_step = 0 self._total_epoch = 0 self._verbose = verbose - def create_dataset(self) -> DynamicDataset: - """Create dataset - - Returns: - DynamicDataset: Must be an subinstance of DynamicDataset - """ - return DynamicDataset( - grpc_thread_num_workers=1, - max_message_length=1024, - feature_handler_caller=None, - ) - @abstractmethod def multiagent_post_process( self, @@ -223,6 +227,12 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: self.set_running(True) try: + while ( + self.data_loader.dataset.readable_block_size + < self.data_loader.batch_size + ): + time.sleep(1) + while self.is_running(): for data in self.data_loader: batch_info = self.multiagent_post_process(data) diff --git a/malib/learner/manager.py b/malib/learner/manager.py index 39b5f14..86e784f 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -34,11 +34,9 @@ Type, Generator, ) -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, Future, CancelledError +from concurrent.futures import ThreadPoolExecutor import os -import traceback import ray from malib.common.task import OptimizationTask @@ -46,10 +44,9 @@ from malib.utils.logging import Logger from malib.utils.exploitability import measure_exploitability from malib.remote.interface import RemoteInterface -from malib.learner.learner import Learner from malib.common.strategy_spec import StrategySpec from malib.common.manager import Manager -from malib.learner.config import TrainingConfig +from malib.learner.config import LearnerConfig from malib.rl.config import Algorithm @@ -66,7 +63,7 @@ def __init__( env_desc: Dict[str, Any], agent_mapping_func: Callable[[AgentID], str], group_info: Dict[str, Any], - training_config: Union[Dict[str, Any], TrainingConfig], + learner_config: LearnerConfig, log_dir: str, resource_config: Dict[str, Any] = None, ray_actor_namespace: str = "learner", @@ -90,19 +87,19 @@ def __init__( super().__init__(verbose=verbose, namespace=ray_actor_namespace) resource_config = resource_config or DEFAULT_RESOURCE_CONFIG - training_config = TrainingConfig.from_raw(training_config) + learner_config = LearnerConfig.from_raw(learner_config) # interface config give the agent type used here and the group mapping if needed # FIXME(ming): resource configuration is not available now, will turn-on in the next version - if training_config.trainer_config.get("use_cuda", False): + if algorithm.trainer_config.get("use_cuda", False): num_gpus = 1 / len(group_info["agent_groups"]) else: num_gpus = 0.0 if not os.path.exists(log_dir): os.makedirs(log_dir) - learner_cls = training_config.learner_type + learner_cls = learner_config.learner_type # update num gpus resource_config["num_gpus"] = num_gpus learner_cls = learner_cls.as_remote(**resource_config) @@ -115,6 +112,7 @@ def __init__( ready_check = [] for rid, agents in group_info["agent_groups"].items(): + agents = tuple(agents) learners[rid] = learner_cls.options( name=f"learner_{rid}", max_concurrency=10, namespace=self.namespace ).remote( @@ -124,9 +122,12 @@ def __init__( action_space=group_info["action_space"][rid], algorithm=algorithm, agent_mapping_func=agent_mapping_func, - governed_agents=tuple(agents), - trainer_config=training_config.trainer_config, - custom_config=training_config.custom_config, + governed_agents=agents, + 
+            custom_config=learner_config.custom_config,
+            feature_handler_gen=learner_config.feature_handler_meta_gen(
+                env_desc, agents[0]
+            ),
             verbose=verbose,
         )
         ready_check.append(learners[rid].ready.remote())
@@ -150,7 +151,7 @@ def __init__(
         self._group_info = group_info
         self._runtime_ids = tuple(group_info["agent_groups"].keys())
         self._env_description = env_desc
-        self._training_config = training_config
+        self._learner_config = learner_config
         self._log_dir = log_dir
         self._agent_mapping_func = agent_mapping_func
         self._learners = learners
diff --git a/malib/mocker/mocker_utils.py b/malib/mocker/mocker_utils.py
index 1f85a5d..e4685e3 100644
--- a/malib/mocker/mocker_utils.py
+++ b/malib/mocker/mocker_utils.py
@@ -22,69 +22,17 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-from typing import Sequence, Dict, Any, Callable, List, Union
+from typing import Sequence, Dict, Any, Callable, List, Tuple
 
 import time
 
-import ray
 
-from ray.util import ActorPool
-from malib.rollout.rollout_config import RolloutConfig
+from malib.rollout.config import RolloutConfig
 from malib.utils.typing import AgentID
 from malib.common.strategy_spec import StrategySpec
-from malib.rollout.rolloutworker import RolloutWorker
-
-
-class FakeRolloutWorker(RolloutWorker):
-    def init_agent_interfaces(
-        self, env_desc: Dict[str, Any], runtime_ids: Sequence[AgentID]
-    ) -> Dict[AgentID, Any]:
-        return {}
-
-    def init_actor_pool(
-        self,
-        env_desc: Dict[str, Any],
-        rollout_config: Dict[str, Any],
-        agent_mapping_func: Callable,
-    ) -> ActorPool:
-        return NotImplementedError
-
-    def init_servers(self):
-        pass
-
-    def rollout(
-        self,
-        runtime_strategy_specs: Dict[str, StrategySpec],
-        stopping_conditions: Dict[str, Any],
-        data_entrypoints: Dict[str, str],
-        trainable_agents: List[AgentID] = None,
-    ):
-        self.set_running(True)
-        return {}
-
-    def simulate(self, runtime_strategy_specs: Dict[str, StrategySpec]):
-        time.sleep(0.5)
-        return {}
-
-    def step_rollout(
-        self,
-        eval_step: bool,
-        rollout_config: Dict[str, Any],
-        dataset_writer_info_dict: Dict[str, Any],
-    ) -> List[Dict[str, Any]]:
-        pass
-
-    def step_simulation(
-        self,
-        runtime_strategy_specs_list: Dict[str, StrategySpec],
-        rollout_config: Dict[str, Any],
-    ) -> Dict[str, Any]:
-        pass
-
-
-from typing import Tuple
-
 from malib.utils.typing import PolicyID
 from malib.common.payoff_manager import PayoffManager
diff --git a/malib/registration.py b/malib/registration.py
deleted file mode 100644
index ec7b2bb..0000000
--- a/malib/registration.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# MIT License
-
-# Copyright (c) 2021 MARL @ SJTU
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import Dict, Callable, Union - - -class Registry: - """Global registry of algorithms, models, preprocessors and environments - - Examples: - >>> # register custom model - >>> Registry.register_custom_model("MyCustomModel", model_class) - >>> # register custom policy - >>> Registry.register_custom_policy("MyCustomPolicy", policy_class) - >>> # register custom environment - >>> Registry.register_custom_env("MyCustomEnvironment", environment_class) - >>> # register custom algorithm - >>> Registry.register_custom_algorithm( - ... name="MyCustomAlgo", - ... policy="registered_policy_name_or_cls", - ... trainer="registered_trainer_name_or_cls", - ... loss="registered_loss_name_or_cls") - >>> - """ - - @staticmethod - def register_custom_algorithm( - name: str, - policy: Union[type, str], - trainer: Union[type, str], - loss: Union[type, str] = None, - ) -> None: - """Register a custom algorithm by name. - - :param name: str, Name to register the algorithm under. - :param policy: Union[type, str], Python class or registered name of policy. - :param trainer: Union[type, str], Python class or registered name of trainer. - :param loss: Union[type, str], Python class or registered name of loss function. - :return: - """ - # _global_registry.register(ALGORITHM, name, policy, trainer, loss) - pass - - @staticmethod - def register_custom_model(name: str, model_class: type) -> None: - """Register a custom model by name. - - :param name: str, Name to register the model under. - :param model_class: type, Python class of the model. - :return: - """ - # _global_registry.register(MODEL, name, model_class) - pass - - @staticmethod - def register_custom_policy(name: str, policy_class: type) -> None: - """Register a custom policy by name. - - :param name: str, Name to register the policy under. - :param policy_class: type, Python class of the policy. - """ - pass - - @staticmethod - def register_custom_env(name: str, env_class: type) -> None: - """Register a custom environment by name. - - :param name: str, Name to register the environment under. - :param env_class: type, Python class of the environment. 
- """ - pass diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py index ba346aa..4cd783d 100644 --- a/malib/rl/common/policy.py +++ b/malib/rl/common/policy.py @@ -249,16 +249,19 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": Policy: A policy instance """ + if isinstance(device, torch.device): + device = device.type + if device is None: - device = "cpu" if not self.use_cuda else "cuda" + device = "cpu" if "cuda" not in self.device else "cuda" - cond1 = "cpu" in device and self.use_cuda - cond2 = "cuda" in device and not self.use_cuda + cond1 = "cpu" in device and "cuda" in self.device + cond2 = "cuda" in device and "cuda" not in self.device if "cpu" in device: - use_cuda = False + _device = device else: - use_cuda = self._custom_config.get("use_cuda", False) + _device = self.device replacement = {} if cond1 or cond2: @@ -273,7 +276,6 @@ def to(self, device: str = None, use_copy: bool = False) -> "Policy": if use_copy: ret = self.copy(self, replacement=replacement) else: - self.use_cuda = use_cuda ret = self return ret diff --git a/malib/rl/config.py b/malib/rl/config.py index 730be86..5935b99 100644 --- a/malib/rl/config.py +++ b/malib/rl/config.py @@ -14,3 +14,5 @@ class Algorithm: trainer: Type[Trainer] model_config: Dict[str, Any] + + trainer_config: Dict[str, Any] diff --git a/malib/rollout/rollout_config.py b/malib/rollout/config.py similarity index 100% rename from malib/rollout/rollout_config.py rename to malib/rollout/config.py diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 0167f77..1048e78 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -33,7 +33,7 @@ from malib.utils.timing import Timing from malib.remote.interface import RemoteInterface from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.inference.client import InferenceClient, PolicyReturnWithObs from malib.rollout.envs.env import Environment from malib.common.strategy_spec import StrategySpec diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index e669a42..146d6f4 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -39,7 +39,7 @@ from malib.common.manager import Manager from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.pb_rolloutworker import PBRolloutWorker @@ -99,7 +99,7 @@ def __init__( super().__init__(verbose=verbose, namespace=ray_actor_namespace) rollout_worker_cls = PBRolloutWorker - worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options() + worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0) workers = [] ready_check = [] for i in range(num_worker): @@ -180,7 +180,9 @@ def submit( for _task in task: validate_strategy_specs(_task.strategy_specs) - self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task)) + self._actor_pool.submit( + lambda actor, _task: actor.rollout.remote(_task), _task + ) if wait: result_list = self.wait() diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 455b8d1..c3f24ec 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -47,7 +47,7 @@ from malib.common.strategy_spec import StrategySpec from 
malib.common.task import RolloutTask from malib.remote.interface import RemoteInterface -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.inference.client import InferenceClient from malib.rollout.inference.env_runner import BasicEnvRunner diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 653f6cd..0cc3659 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Dict, Any +from typing import Dict, Any, Union from malib.common.task import TaskType, OptimizationTask, RolloutTask from malib.scenarios import Scenario @@ -30,8 +30,11 @@ from malib.utils.logging import Logger from malib.backend.league import League from malib.learner.manager import LearnerManager +from malib.learner.config import LearnerConfig +from malib.rollout.config import RolloutConfig from malib.rollout.manager import RolloutWorkerManager from malib.rollout.inference.manager import InferenceManager +from malib.rl.config import Algorithm class SARLScenario(Scenario): @@ -40,9 +43,9 @@ def __init__( name: str, log_dir: str, env_desc: Dict[str, Any], - algorithms: Dict[str, Any], - training_config: Dict[str, Any], - rollout_config: Dict[str, Any], + algorithm: Algorithm, + learner_config: Union[Dict[str, Any], LearnerConfig], + rollout_config: Union[Dict[str, Any], RolloutConfig], stopping_conditions: Dict[str, Any], resource_config: Dict[str, Any] = None, ): @@ -50,9 +53,9 @@ def __init__( name, log_dir, env_desc, - algorithms, + algorithm, lambda agent: "default", - training_config, + learner_config, rollout_config, stopping_conditions, ) @@ -66,15 +69,13 @@ def create_global_stopper(self) -> StoppingCondition: def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input learner_manager = LearnerManager( - experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, - algorithms=scenario.algorithms, + algorithm=scenario.algorithm, env_desc=scenario.env_desc, agent_mapping_func=scenario.agent_mapping_func, group_info=scenario.group_info, - training_config=scenario.training_config, + learner_config=scenario.learner_config, log_dir=scenario.log_dir, - remote_mode=True, resource_config=scenario.resource_config["training"], ray_actor_namespace="learner_{}".format(experiment_tag), verbose=verbose, @@ -84,7 +85,8 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = group_info=scenario.group_info, ray_actor_namespace="inference_{}".format(experiment_tag), model_entry_point=learner_manager.learner_entrypoints, - scenario=scenario, + algorithm=scenario.algorithm, + verbose=verbose, ) rollout_manager = RolloutWorkerManager( @@ -99,27 +101,23 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = verbose=verbose, ) - league = League( - learner_manager, rollout_manager, inference_manager, namespace=experiment_tag - ) + league = League(learner_manager, rollout_manager, inference_manager) + # TODO(ming): further check is needed optimization_task = OptimizationTask( - active_agents=scenario.env_desc["possible_agents"], stop_conditions=scenario.stopping_conditions["training"], + strategy_specs=None, + active_agents=None, ) - strategy_specs = 
learner_manager.get_strategy_specs() - rollout_task = RolloutTask( - task_type=TaskType.ROLLOUT, - strategy_specs=strategy_specs, + strategy_specs=None, stopping_conditions=scenario.stopping_conditions["rollout"], data_entrypoint_mapping=learner_manager.data_entrypoints, ) evaluation_task = RolloutTask( - task_type=TaskType.EVALUATION, - strategy_specs=strategy_specs, + strategy_specs=None, ) stopper = scenario.create_global_stopper() diff --git a/malib/scenarios/scenario.py b/malib/scenarios/scenario.py index 03af706..d5e1ea2 100644 --- a/malib/scenarios/scenario.py +++ b/malib/scenarios/scenario.py @@ -31,6 +31,10 @@ from malib.utils.typing import AgentID from malib.utils.stopping_conditions import StoppingCondition +from malib.rl.config import Algorithm +from malib.learner.config import LearnerConfig +from malib.rollout.config import RolloutConfig + DEFAULT_STOPPING_CONDITIONS = {} @@ -91,16 +95,16 @@ def __init__( name: str, log_dir: str, env_desc: Dict[str, Any], - algorithms: Dict[str, Any], + algorithm: Algorithm, agent_mapping_func: LambdaType, - training_config: Dict[str, Any], - rollout_config: Dict[str, Any], + learner_config: LearnerConfig, + rollout_config: RolloutConfig, stopping_conditions: Dict[str, Any], ): self.name = name self.log_dir = log_dir self.env_desc = env_desc - self.algorithms = algorithms + self.algorithm = algorithm self.agent_mapping_func = agent_mapping_func # then generate grouping information here self.group_info = form_group_info(env_desc, agent_mapping_func) @@ -109,8 +113,8 @@ def __init__( env_desc["observation_spaces"], env_desc["action_spaces"], ) - self.training_config = training_config - self.rollout_config = rollout_config + self.learner_config = LearnerConfig.from_raw(learner_config) + self.rollout_config = RolloutConfig.from_raw(rollout_config) self.stopping_conditions = stopping_conditions or DEFAULT_STOPPING_CONDITIONS def copy(self): diff --git a/malib/settings.py b/malib/settings.py index af9b66c..428a46b 100644 --- a/malib/settings.py +++ b/malib/settings.py @@ -1,49 +1,3 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- import logging -import os - -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - -LOG_DIR = os.path.join(BASE_DIR, "logs") -LOG_LEVEL = logging.INFO -STATISTIC_FEEDBACK = True -DATA_FEEDBACK = False -USE_REMOTE_LOGGER = True -USE_MONGO_LOGGER = False -PROFILING = False - -PARAMETER_SERVER_ACTOR = "ParameterServer" -OFFLINE_DATASET_ACTOR = "OfflineDataset" -COORDINATOR_SERVER_ACTOR = "coordinator" - -# default episode capacity when initializing -DEFAULT_EPISODE_INIT_CAPACITY = int(1e6) -# default episode maximum capacity -DEFAULT_EPISODE_CAPACITY = 30000 # int(1e15) -# related to each group of expr settings -DEFAULT_EPISODE_BLOCK_SIZE = int(75) -PICKLE_PROTOCOL_VER = 4 -PARAM_DIR = os.path.join(BASE_DIR, "../checkpoints") -DATASET_DIR = os.path.join(BASE_DIR, "dataset") +LOG_LEVEL = logging.DEBUG diff --git a/malib/utils/general.py b/malib/utils/general.py index 2bf25ad..778cbca 100644 --- a/malib/utils/general.py +++ b/malib/utils/general.py @@ -41,8 +41,6 @@ import torch import numpy as np -from malib import settings - T = TypeVar("T") diff --git a/tests/backend/test_dynamic_dataset.py b/tests/backend/test_dynamic_dataset.py index 304ecb0..5c62387 100644 --- a/tests/backend/test_dynamic_dataset.py +++ b/tests/backend/test_dynamic_dataset.py @@ -83,6 +83,7 @@ def test_sync_grpc_service_get(self): for k, v in _spaces.items() }, ) + dataset.start_server() # send data print("send 10 piece of data, entrypoint=", dataset.entrypoint) @@ -124,6 +125,7 @@ def test_async_grpc_service_get(self): k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() }, ) + dataset.start_server() def start_send(batch, entrypoint): print("send 10 piece of data, entrypoint=", entrypoint) diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index ddddaa3..33e69f9 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -7,7 +7,7 @@ from malib.rollout.inference import env_runner from malib.rollout.inference.client import InferenceClient from malib.rollout.envs import mdp -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rl.random import RandomPolicy diff --git a/tests/rollout/test_mdp_env.py b/tests/rollout/test_mdp_env.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/rollout/test_open_spiel.py b/tests/rollout/test_open_spiel.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index e880a3e..9d55103 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -26,19 +26,16 @@ import pytest import ray -from malib.backend.dataset_server.data_loader import DynamicDataset from malib.common.task import RolloutTask from malib.common.strategy_spec import StrategySpec from malib.rl.random import RandomPolicy from malib.rl.config import Algorithm from malib.rollout.envs.random import env_desc_gen -from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.config import RolloutConfig from malib.rollout.pb_rolloutworker import PBRolloutWorker from malib.rollout.inference.manager import InferenceManager from malib.scenarios.scenario import form_group_info -from malib.utils.tianshou_batch import Batch -from malib.utils.typing import AgentID def gen_rollout_config(inference_server_type: str): @@ -61,9 +58,7 @@ def gen_common_requirements(n_player: int): env_desc = env_desc_gen(num_agents=n_player) 
algorithm = Algorithm( - policy=RandomPolicy, - trainer=None, - model_config=None, + policy=RandomPolicy, trainer=None, model_config=None, trainer_config={} ) rollout_config = RolloutConfig( @@ -79,11 +74,14 @@ def gen_common_requirements(n_player: int): return env_desc, algorithm, rollout_config, group_info +import numpy as np + from malib.learner.learner import Learner from gym import spaces -from malib.learner.learner import Learner from malib.learner.manager import LearnerManager -from malib.learner.config import TrainingConfig +from malib.learner.config import LearnerConfig +from malib.utils.episode import Episode +from malib.backend.dataset_server.feature import BaseFeature class FakeLearner(Learner): @@ -94,6 +92,28 @@ def multiagent_post_process( pass +class FakeFeatureHandler(BaseFeature): + + pass + + +def feature_handler_meta_gen(env_desc, agent_id): + def f(device): + _spaces = { + Episode.DONE: spaces.Discrete(1), + Episode.CUR_OBS: env_desc["observation_spaces"][agent_id], + Episode.ACTION: env_desc["action_spaces"][agent_id], + Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32), + Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id], + } + np_memory = { + k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items() + } + return FakeFeatureHandler(_spaces, np_memory, device) + + return f + + @pytest.mark.parametrize("n_player", [1, 2]) class TestRolloutWorker: def test_rollout(self, n_player: int): @@ -163,8 +183,10 @@ def test_rollout_with_data_entrypoint(self, n_player: int): env_desc=env_desc, agent_mapping_func=lambda agent: "default", group_info=group_info, - training_config=TrainingConfig( - trainer_config={}, learner_type=FakeLearner, custom_config=None + learner_config=LearnerConfig( + learner_type=FakeLearner, + feature_handler_meta_gen=feature_handler_meta_gen, + custom_config=None, ), log_dir=log_dir, ) diff --git a/tests/rollout/test_rollout_manager.py b/tests/rollout/test_rollout_manager.py index 2802ef4..4258e8e 100644 --- a/tests/rollout/test_rollout_manager.py +++ b/tests/rollout/test_rollout_manager.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-from typing import Dict, Any +from typing import Dict, Any, Callable import pytest import ray @@ -30,23 +30,33 @@ from gym import spaces from pytest_mock import MockerFixture +from malib.common.task import RolloutTask from malib.common.strategy_spec import StrategySpec +from malib.rollout.config import RolloutConfig from malib.rollout.manager import RolloutWorkerManager -from malib.mocker.mocker_utils import FakeRolloutWorker +from malib.rl.random import RandomPolicy +from malib.scenarios.scenario import form_group_info +from malib.learner.manager import LearnerManager +from malib.learner.config import LearnerConfig +from malib.rollout.inference.manager import InferenceManager +from test_pb_rollout_worker import ( + feature_handler_meta_gen, + FakeFeatureHandler, + FakeLearner, + gen_common_requirements, +) def create_manager( - mocker: MockerFixture, stopping_conditions: Dict[str, Any], rollout_config: Dict[str, Any], env_desc: Dict[str, Any], + agent_mapping_func: Callable, ): - mocker.patch("malib.rollout.manager.PBRolloutWorker", new=FakeRolloutWorker) manager = RolloutWorkerManager( - experiment_tag="test_rollout_manager", stopping_conditions=stopping_conditions, num_worker=1, - agent_mapping_func=lambda agent: agent, + group_info=form_group_info(env_desc, agent_mapping_func), rollout_config=rollout_config, env_desc=env_desc, log_dir="./logs", @@ -55,105 +65,57 @@ def create_manager( @pytest.mark.parametrize("n_players", [1, 2]) -@pytest.mark.parametrize("inference_server_type", ["local", "ray"]) class TestRolloutManager: - def test_rollout_task_send( - self, mocker: MockerFixture, n_players: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - agents = [f"player_{i}" for i in range(n_players)] - manager = create_manager( - mocker, - stopping_conditions={"rollout": {"max_iteration": 2}}, - rollout_config={ - "fragment_length": 100, - "max_step": 10, - "num_eval_episodes": 2, - "num_threads": 1, - "num_env_per_thread": 1, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "timestep", - "postprocessor_types": None, - "eval_interval": 2, - "inference_server": inference_server_type, - }, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - ) - - strategy_specs = { - agent: StrategySpec( - identifier=agent, - policy_ids=["policy_0"], - meta_data={ - "prob_list": [1.0], - "policy_cls": None, - "kwargs": None, - "experiment_tag": "test_rollout_manager", - }, + def test_rollout_task_send(self, mocker: MockerFixture, n_players: int): + with ray.init(local_mode=True): + env_desc, algorithm, rollout_config, group_info = gen_common_requirements( + n_players ) - for agent in agents - } - task_list = [ - { - "trainable_agents": agents, - "data_entrypoints": None, - "strategy_specs": strategy_specs, + inference_namespace = "test_pb_rolloutworker" + manager = create_manager( + stopping_conditions={"rollout": {"max_iteration": 2}}, + rollout_config=RolloutConfig(), + env_desc=env_desc, + agent_mapping_func=lambda agent: "default", + ) + + learner_manager = LearnerManager( + stopping_conditions={"max_iteration": 10}, + algorithm=algorithm, + env_desc=env_desc, + agent_mapping_func=lambda agent: "default", + group_info=group_info, + learner_config=LearnerConfig( + learner_type=FakeLearner, + feature_handler_meta_gen=feature_handler_meta_gen, + custom_config=None, + ), + 
+                log_dir="./logs",
+            )
+
+            infer_manager = InferenceManager(
+                group_info=group_info,
+                ray_actor_namespace=inference_namespace,
+                algorithm=algorithm,
+                model_entry_point=learner_manager.learner_entrypoints,
+            )
+
+            rollout_config.inference_entry_points = infer_manager.inference_entry_points
+
+            strategy_specs = {
+                agent: StrategySpec(
+                    policy_cls=RandomPolicy,
+                    observation_space=env_desc["observation_spaces"][agent],
+                    action_space=env_desc["action_spaces"][agent],
+                    policy_ids=["policy_0"],
+                )
+                for agent in env_desc["possible_agents"]
+            }
+
+            task = RolloutTask(
+                strategy_specs=strategy_specs,
+                stopping_conditions={"max_iteration": 10},
+                data_entrypoints=None,
+            )
+
+            results = manager.submit(task, wait=True)
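
Reviewer note on the new LearnerConfig.feature_handler_meta_gen field: it is a two-level factory. LearnerManager calls it once per agent group with the environment description and a representative agent id (see the manager.py hunk above), and the Learner then calls the returned closure with its torch device to build the concrete BaseFeature handler (see the learner.py hunk). A minimal sketch mirroring the FakeFeatureHandler fixture in tests/rollout/test_pb_rollout_worker.py; the space layout and the buffer length of 100 are illustrative values, and MyFeatureHandler is a hypothetical name:

    import numpy as np
    from gym import spaces

    from malib.backend.dataset_server.feature import BaseFeature
    from malib.utils.episode import Episode


    class MyFeatureHandler(BaseFeature):
        # A trivial subclass suffices for the sketch, exactly as the
        # FakeFeatureHandler fixture does in the tests above.
        pass


    def feature_handler_meta_gen(env_desc, agent_id):
        """Level one: called by LearnerManager with (env_desc, agent_id)."""

        def gen(device):
            # Level two: called inside the Learner with its torch device.
            _spaces = {
                Episode.DONE: spaces.Discrete(1),
                Episode.CUR_OBS: env_desc["observation_spaces"][agent_id],
                Episode.ACTION: env_desc["action_spaces"][agent_id],
                Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32),
                Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id],
            }
            np_memory = {
                k: np.zeros((100,) + v.shape, dtype=v.dtype)
                for k, v in _spaces.items()
            }
            return MyFeatureHandler(_spaces, np_memory, device)

        return gen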
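
Relatedly, DynamicDataset now separates construction from gRPC startup: the constructor accepts either a ready feature_handler instance or a feature_handler_cls with kwargs, and no port is bound until start_server() is called, which is why the updated tests invoke dataset.start_server() explicitly. A usage sketch; env_desc, agent_id and device are placeholders, and feature_handler_meta_gen is the helper sketched above:

    from malib.backend.dataset_server.data_loader import DynamicDataset

    # Construct the dataset: the feature handler is attached, but no
    # gRPC resources are allocated at this point.
    dataset = DynamicDataset(
        grpc_thread_num_workers=2,
        max_message_length=1024,
        feature_handler=feature_handler_meta_gen(env_desc, agent_id)(device),
    )

    # Bring the service up explicitly: this picks a free port and makes
    # dataset.entrypoint available for producers to connect to.
    dataset.start_server()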
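
Finally, the one-line change in RolloutWorkerManager.submit is a genuine bug fix: ray.util.ActorPool.submit(fn, value) takes the task as an explicit second argument and forwards it to fn, so the old call, which supplied only the lambda, omitted a required argument. The fixed pattern, with workers and tasks as placeholders:

    from ray.util import ActorPool

    pool = ActorPool(workers)
    for _task in tasks:
        # ActorPool invokes fn(actor, value); the task must be handed to
        # submit() explicitly rather than captured from the loop variable.
        pool.submit(lambda actor, task: actor.rollout.remote(task), _task)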