Skip to content

Commit

Permalink
tmp save: test rollout manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Ming Zhou committed Nov 17, 2023
1 parent 9c3af38 commit 7a77d43
Show file tree
Hide file tree
Showing 25 changed files with 242 additions and 792 deletions.
74 changes: 27 additions & 47 deletions examples/sarl/ppo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from argparse import ArgumentParser

from malib.learner import IndependentAgent
from malib.scenarios.marl_scenario import MARLScenario

from malib.runner import run
from malib.scenarios import sarl_scenario
from malib.rl.config import Algorithm
from malib.rl.ppo import PPOPolicy, PPOTrainer, DEFAULT_CONFIG
from malib.learner.config import LearnerConfig
from malib.rollout.config import RolloutConfig
from malib.rollout.envs.gym import env_desc_gen


Expand All @@ -23,59 +24,38 @@
trainer_config["total_timesteps"] = int(1e6)
trainer_config["use_cuda"] = args.use_cuda

training_config = {
"learner_type": IndependentAgent,
"trainer_config": trainer_config,
"custom_config": {},
}

rollout_config = {
"fragment_length": 2000, # determines the size of the sent data block
"max_step": 200,
"num_eval_episodes": 10,
"num_threads": 2,
"num_env_per_thread": 10,
"num_eval_threads": 1,
"use_subproc_env": False,
"batch_mode": "time_step",
"postprocessor_types": ["defaults"],
# every # rollout epoch run evaluation.
"eval_interval": 1,
"inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray`
}

# one to one, no sharing, if sharing, implemented as:
# agent_mapping_func = lambda agent: "default"
agent_mapping_func = lambda agent: agent

algorithms = {
"default": (
PPOPolicy,
PPOTrainer,
# model configuration, None as default
{},
{"use_cuda": args.use_cuda},
)
}

env_description = env_desc_gen(env_id=args.env_id, scenario_configs={})
runtime_logdir = os.path.join(args.log_dir, f"sa_ppo_gym/{time.time()}")
runtime_logdir = os.path.join(
args.log_dir, f"gym/{args.env_id}/independent_ppo/{time.time()}"
)

if not os.path.exists(runtime_logdir):
os.makedirs(runtime_logdir)

scenario = MARLScenario(
scenario = sarl_scenario.SARLScenario(
name=f"ppo-gym-{args.env_id}",
log_dir=runtime_logdir,
algorithms=algorithms,
env_description=env_description,
training_config=training_config,
rollout_config=rollout_config,
agent_mapping_func=agent_mapping_func,
env_desc=env_desc_gen(env_id=args.env_id),
algorithm=Algorithm(
trainer=PPOTrainer,
policy=PPOPolicy,
model_config=None, # use default
trainer_config=trainer_config,
),
learner_config=LearnerConfig(
learner_type=IndependentAgent,
feature_handler_meta_gen=None,
custom_config={},
),
rollout_config=RolloutConfig(
num_workers=1,
),
agent_mapping_func=lambda agent: agent,
stopping_conditions={
"training": {"max_iteration": int(1e10)},
"rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0},
},
)

run(scenario)
results = sarl_scenario.execution_plan(
experiment_tag=scenario.name, scenario=scenario, verbose=True
)
13 changes: 9 additions & 4 deletions malib/backend/dataset_server/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,24 @@ def __init__(
self,
grpc_thread_num_workers: int,
max_message_length: int,
feature_handler_cls: Type[BaseFeature],
feature_handler: BaseFeature = None,
feature_handler_cls: Type[BaseFeature] = None,
**feature_handler_kwargs,
) -> None:
super().__init__()

# start a service as thread
self.feature_handler: BaseFeature = feature_handler_cls(
self.feature_handler: BaseFeature = feature_handler or feature_handler_cls(
**feature_handler_kwargs
)
self.grpc_thread_num_workers = grpc_thread_num_workers
self.max_message_length = max_message_length

def start_server(self):
self.server_port = find_free_port()
self.server = service_wrapper(
grpc_thread_num_workers,
max_message_length,
self.grpc_thread_num_workers,
self.max_message_length,
self.server_port,
)(self.feature_handler)
self.server.start()
Expand Down
193 changes: 0 additions & 193 deletions malib/backend/offline_dataset_server.py

This file was deleted.

Loading

0 comments on commit 7a77d43

Please sign in to comment.