From c748df8139bd5756eb1bac6d373997c137f336f9 Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Fri, 10 May 2024 12:16:30 +0200 Subject: [PATCH 1/9] Changed comment. Signed-off-by: Simon Zehnder --- rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py index f7e818a659cf..38690232351e 100644 --- a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py @@ -288,7 +288,7 @@ def add( for i in range(len(eps)) ] ) - # Increase index. + # Increase index to the new length of `self._indices`. j = len(self._indices) @override(EpisodeReplayBuffer) From c940fc496bbffa1bc659cb28ab06465593e9e5a4 Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Tue, 21 May 2024 13:03:37 +0200 Subject: [PATCH 2/9] Added example to restore 1 of n agents from checkpoint using Pendulum multi-agent environment. Signed-off-by: Simon Zehnder --- .../restore_1_of_n_agents_from_checkpoint.py | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py diff --git a/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py new file mode 100644 index 000000000000..b25e4bb0be76 --- /dev/null +++ b/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py @@ -0,0 +1,155 @@ +"""Simple example of loading module weights for 1 of n agents from checkpoint. + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --num-agents=2 +--checkpoint-freq=4 --checkpoint-at-end` + +Control the number of agents and policies (RLModules) via --num-agents and +--num-policies. + +Control the number of checkpoints by setting --checkpoint-freq to a value > 0. +Note that the checkpoint frequency is per iteration and this example needs at +least a single checkpoint to load the RLModule weights for policy 0. +If --checkpoint-at-end is set, a checkpoint will be saved at the end of the +experiment. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + +Results to expect +----------------- +You should expect a reward of -400.0 eventually being achieved by a simple +PPO policy (no tuning, just using RLlib's default settings). In the second +run of the experiment, the MARL module weights for policy 0 are restored from +the checkpoint of the first run. 
The reward should be -400.0 again, but the +training time should be shorter (around 30 iterations instead of 190): + ++---------------------+------------+----------------------+--------+ +| Trial name | status | loc | iter | +|---------------------+------------+----------------------+--------+ +| PPO_env_7c6be_00000 | TERMINATED | 192.168.1.111:101257 | 26 | ++---------------------+------------+----------------------+--------+ + ++------------------+-------+-------------------+-------------+-------------+ +| total time (s) | ts | combined return | return p0 | return p1 | ++------------------+-------+-------------------+-------------+-------------| +| 86.7995 | 13312 | -395.822 | -315.359 | -325.237 | ++------------------+-------+-------------------+-------------+-------------+ +""" + +import os +from ray.air.constants import TRAINING_ITERATION +from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum +from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls, register_env + +parser = add_rllib_example_script_args( + default_iters=200, + default_timesteps=100000, + default_reward=-400.0, +) +# TODO (sven): This arg is currently ignored (hard-set to 2). +parser.add_argument("--num-policies", type=int, default=2) + + +if __name__ == "__main__": + args = parser.parse_args() + + # args.enable_new_api_stack = True + # args.num_agents = 2 + # args.checkpoint_freq = 4 + # args.checkpoint_at_end = True + # Register our environment with tune. + if args.num_agents > 1: + register_env( + "env", + lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}), + ) + else: + raise ValueError( + f"`num_agents` must be > 1, but is {args.num_agents}." + "Read the script docstring for more information." + ) + + assert args.checkpoint_freq > 0, ( + "This example requires at least one checkpoint to load the RLModule " + "weights for policy 0." + ) + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment("env") + .training( + train_batch_size_per_learner=512, + mini_batch_size_per_learner=64, + lambda_=0.1, + gamma=0.95, + lr=0.0003, + vf_clip_param=10.0, + ) + .rl_module( + model_config_dict={"fcnet_activation": "relu"}, + ) + ) + + # Add a simple multi-agent setup. + if args.num_agents > 0: + base_config.multi_agent( + policies={f"p{i}" for i in range(args.num_agents)}, + policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}", + ) + + # Augment the base config with further settings and train the agents. + results = run_rllib_example_script_experiment(base_config, args) + + # Create an env instance to get the observation and action spaces. + env = MultiAgentPendulum(config={"num_agents": args.num_agents}) + # Get the default module spec from the algorithm config. + module_spec = base_config.get_default_rl_module_spec() + module_spec.model_config_dict = base_config.model_config | { + "fcnet_activation": "relu", + } + module_spec.observation_space = env.envs[0].observation_space + module_spec.action_space = env.envs[0].action_space + # Create the module for each policy, but policy 0. + module_specs = {} + for i in range(1, args.num_agents or 1): + module_specs[f"p{i}"] = module_spec + + # Now swap in the RLModule weights for policy 0. 
+ chkpt_path = results.get_best_result().checkpoint.path + p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0") + module_spec.load_state_path = p_0_module_state_path + module_specs["p0"] = module_spec + + # Create the MARL module. + marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) + # Define the MARL module in the base config. + base_config.rl_module(rl_module_spec=marl_module_spec) + # We need to re-register the environment when starting a new run. + register_env( + "env", + lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}), + ) + # Define stopping criteria. + stop = { + f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400, + f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000, + TRAINING_ITERATION: 30, + } + + # Run the experiment again with the restored MARL module. + run_rllib_example_script_experiment(base_config, args, stop=stop) From 504edddd7f4f884bebf7cb5498b94d5ceef2f23b Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Tue, 21 May 2024 13:13:50 +0200 Subject: [PATCH 3/9] Added example to BUILD file. Signed-off-by: Simon Zehnder --- rllib/BUILD | 9 +++++++++ .../multi_agent/restore_1_of_n_agents_from_checkpoint.py | 4 ---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 9333e7b1adeb..1550664d72db 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2732,6 +2732,15 @@ py_test( # args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"], # ) +py_test( + name = "examples/multi_agent/restore_1_of_n_agents_from_checkpoint", + main = "examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py", + tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core"], + size = "large", + srcs = ["examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py"], + args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--checkpoint-freq=4", "--checkpoint-at-end", "--num-cpus=4"] +) + py_test( name = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned", main = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py", diff --git a/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py index b25e4bb0be76..474ffa818940 100644 --- a/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py +++ b/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py @@ -67,10 +67,6 @@ if __name__ == "__main__": args = parser.parse_args() - # args.enable_new_api_stack = True - # args.num_agents = 2 - # args.checkpoint_freq = 4 - # args.checkpoint_at_end = True # Register our environment with tune. if args.num_agents > 1: register_env( From 92a6cd36d7818a43daf6a9b3823e065dcbccd08e Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Tue, 21 May 2024 20:14:34 +0200 Subject: [PATCH 4/9] Modification due to @sven1977's review. 
Signed-off-by: Simon Zehnder --- rllib/BUILD | 10 +- .../restore_1_of_n_agents_from_checkpoint.py | 257 +++++++++--------- .../restore_1_of_n_agents_from_checkpoint.py | 151 ---------- 3 files changed, 140 insertions(+), 278 deletions(-) delete mode 100644 rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py diff --git a/rllib/BUILD b/rllib/BUILD index 1550664d72db..9453c380747a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2140,10 +2140,10 @@ py_test( #@OldAPIStack py_test( - name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint", + name = "examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint", tags = ["team:rllib", "exclusive", "examples"], size = "medium", - srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"], + srcs = ["examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint.py"], args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"] ) @@ -2733,11 +2733,11 @@ py_test( # ) py_test( - name = "examples/multi_agent/restore_1_of_n_agents_from_checkpoint", - main = "examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py", + name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint", + main = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py", tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core"], size = "large", - srcs = ["examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py"], + srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"], args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--checkpoint-freq=4", "--checkpoint-at-end", "--num-cpus=4"] ) diff --git a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py index 9c4dc3805613..69f20801b878 100644 --- a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py +++ b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py @@ -1,140 +1,153 @@ -# TODO (sven): Move this example script into the new API stack. - -"""Simple example of how to restore only one of n agents from a trained -multi-agent Algorithm using Ray tune. - -Control the number of agents and policies via --num-agents and --num-policies. +"""An example script showing how to load module weights for 1 of n agents +from checkpoint. + +This example: + - Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies. + - Saves a checkpoint of the `MultiAgentRLModule` used every `--checkpoint-freq` + iterations. + - Stops the experiments after the agents reach a combined return of `-800`. + - Picks the best checkpoint by combined return and restores policy 0 from it. + - Runs a second experiment with the restored `RLModule` for policy 0 and + a fresh `RLModule` for the other policies. + - Stops the second experiment after the agents reach a combined return of `-800`. + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --num-agents=2 +--checkpoint-freq=4 --checkpoint-at-end` + +Control the number of agents and policies (RLModules) via --num-agents and +--num-policies. + +Control the number of checkpoints by setting --checkpoint-freq to a value > 0. +Note that the checkpoint frequency is per iteration and this example needs at +least a single checkpoint to load the RLModule weights for policy 0. +If --checkpoint-at-end is set, a checkpoint will be saved at the end of the +experiment. 
+ +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + +Results to expect +----------------- +You should expect a reward of -400.0 eventually being achieved by a simple +single PPO policy (no tuning, just using RLlib's default settings). In the +second run of the experiment, the MARL module weights for policy 0 are +restored from the checkpoint of the first run. The reward for a single agent +should be -400.0 again, but the training time should be shorter (around 30 +iterations instead of 190). """ -import argparse -import gymnasium as gym import os -import random - -import ray -from ray import air, tune from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.algorithm import Algorithm -from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole -from ray.rllib.policy.policy import Policy -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, +from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum +from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, ) -from ray.rllib.utils.test_utils import check_learning_achieved +from ray.tune.registry import get_trainable_cls, register_env -tf1, tf, tfv = try_import_tf() - -parser = argparse.ArgumentParser() - -parser.add_argument("--num-agents", type=int, default=4) -parser.add_argument("--num-policies", type=int, default=2) -parser.add_argument("--pre-training-iters", type=int, default=5) -parser.add_argument("--num-cpus", type=int, default=0) -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", -) -parser.add_argument( - "--as-test", - action="store_true", - help="Whether this script should be run as a test: --stop-reward must " - "be achieved within --stop-timesteps AND --stop-iters.", -) -parser.add_argument( - "--stop-iters", type=int, default=200, help="Number of iterations to train." -) -parser.add_argument( - "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train." -) -parser.add_argument( - "--stop-reward", type=float, default=150.0, help="Reward at which we stop training." +parser = add_rllib_example_script_args( + default_iters=200, + default_timesteps=100000, + default_reward=-400.0, ) +# TODO (sven): This arg is currently ignored (hard-set to 2). +parser.add_argument("--num-policies", type=int, default=2) + if __name__ == "__main__": args = parser.parse_args() - ray.init(num_cpus=args.num_cpus or None) - - # Get obs- and action Spaces. - single_env = gym.make("CartPole-v1") - obs_space = single_env.observation_space - act_space = single_env.action_space - - # Setup PPO with an ensemble of `num_policies` different policies. 
- policies = { - f"policy_{i}": (None, obs_space, act_space, None) - for i in range(args.num_policies) - } - policy_ids = list(policies.keys()) - - def policy_mapping_fn(agent_id, episode, worker, **kwargs): - pol_id = random.choice(policy_ids) - return pol_id - - config = ( - PPOConfig() - .environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents}) - .framework(args.framework) - .training(num_sgd_iter=10) - .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn) - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) + # Register our environment with tune. + if args.num_agents > 1: + register_env( + "env", + lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}), + ) + else: + raise ValueError( + f"`num_agents` must be > 1, but is {args.num_agents}." + "Read the script docstring for more information." + ) + + assert args.checkpoint_freq > 0, ( + "This example requires at least one checkpoint to load the RLModule " + "weights for policy 0." ) - # Do some training and store the checkpoint. - results = tune.Tuner( - "PPO", - param_space=config.to_dict(), - run_config=air.RunConfig( - stop={TRAINING_ITERATION: args.pre_training_iters}, - verbose=1, - checkpoint_config=air.CheckpointConfig( - checkpoint_frequency=1, checkpoint_at_end=True - ), - ), - ).fit() - print("Pre-training done.") - - best_checkpoint = results.get_best_result().checkpoint - print(f".. best checkpoint was: {best_checkpoint}") - - policy_0_checkpoint = os.path.join( - best_checkpoint.to_directory(), "policies/policy_0" + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment("env") + .training( + train_batch_size_per_learner=512, + mini_batch_size_per_learner=64, + lambda_=0.1, + gamma=0.95, + lr=0.0003, + vf_clip_param=10.0, + ) + .rl_module( + model_config_dict={"fcnet_activation": "relu"}, + ) ) - restored_policy_0 = Policy.from_checkpoint(policy_0_checkpoint) - restored_policy_0_weights = restored_policy_0.get_weights() - print("Starting new tune.Tuner().fit()") - # Start our actual experiment. - stop = { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - TRAINING_ITERATION: args.stop_iters, + # Add a simple multi-agent setup. + if args.num_agents > 0: + base_config.multi_agent( + policies={f"p{i}" for i in range(args.num_agents)}, + policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}", + ) + + # Augment the base config with further settings and train the agents. + results = run_rllib_example_script_experiment(base_config, args) + + # Create an env instance to get the observation and action spaces. + env = MultiAgentPendulum(config={"num_agents": args.num_agents}) + # Get the default module spec from the algorithm config. + module_spec = base_config.get_default_rl_module_spec() + module_spec.model_config_dict = base_config.model_config | { + "fcnet_activation": "relu", } - - class RestoreWeightsCallback(DefaultCallbacks): - def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None: - algorithm.set_weights({"policy_0": restored_policy_0_weights}) - - # Make sure, the non-1st policies are not updated anymore. 
- config.policies_to_train = [pid for pid in policy_ids if pid != "policy_0"] - config.callbacks(RestoreWeightsCallback) - - results = tune.run( - "PPO", - stop=stop, - config=config.to_dict(), - verbose=1, + module_spec.observation_space = env.envs[0].observation_space + module_spec.action_space = env.envs[0].action_space + # Create the module for each policy, but policy 0. + module_specs = {} + for i in range(1, args.num_agents or 1): + module_specs[f"p{i}"] = module_spec + + # Now swap in the RLModule weights for policy 0. + chkpt_path = results.get_best_result().checkpoint.path + p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0") + module_spec.load_state_path = p_0_module_state_path + module_specs["p0"] = module_spec + + # Create the MARL module. + marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) + # Define the MARL module in the base config. + base_config.rl_module(rl_module_spec=marl_module_spec) + # We need to re-register the environment when starting a new run. + register_env( + "env", + lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}), ) + # Define stopping criteria. + stop = { + # TODO (simon): Change to -800 once the metrics are fixed. Currently + # the combined return is not correctly computed. + f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400, + f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000, + TRAINING_ITERATION: 30, + } - if args.as_test: - check_learning_achieved(results, args.stop_reward) - - ray.shutdown() + # Run the experiment again with the restored MARL module. + run_rllib_example_script_experiment(base_config, args, stop=stop) diff --git a/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py deleted file mode 100644 index 474ffa818940..000000000000 --- a/rllib/examples/multi_agent/restore_1_of_n_agents_from_checkpoint.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Simple example of loading module weights for 1 of n agents from checkpoint. - -How to run this script ----------------------- -`python [script file name].py --enable-new-api-stack --num-agents=2 ---checkpoint-freq=4 --checkpoint-at-end` - -Control the number of agents and policies (RLModules) via --num-agents and ---num-policies. - -Control the number of checkpoints by setting --checkpoint-freq to a value > 0. -Note that the checkpoint frequency is per iteration and this example needs at -least a single checkpoint to load the RLModule weights for policy 0. -If --checkpoint-at-end is set, a checkpoint will be saved at the end of the -experiment. - -For debugging, use the following additional command line options -`--no-tune --num-env-runners=0` -which should allow you to set breakpoints anywhere in the RLlib code and -have the execution stop there for inspection and debugging. - -For logging to your WandB account, use: -`--wandb-key=[your WandB API key] --wandb-project=[some project name] ---wandb-run-name=[optional: WandB run name (within the defined project)]` - -Results to expect ------------------ -You should expect a reward of -400.0 eventually being achieved by a simple -PPO policy (no tuning, just using RLlib's default settings). In the second -run of the experiment, the MARL module weights for policy 0 are restored from -the checkpoint of the first run. 
The reward should be -400.0 again, but the -training time should be shorter (around 30 iterations instead of 190): - -+---------------------+------------+----------------------+--------+ -| Trial name | status | loc | iter | -|---------------------+------------+----------------------+--------+ -| PPO_env_7c6be_00000 | TERMINATED | 192.168.1.111:101257 | 26 | -+---------------------+------------+----------------------+--------+ - -+------------------+-------+-------------------+-------------+-------------+ -| total time (s) | ts | combined return | return p0 | return p1 | -+------------------+-------+-------------------+-------------+-------------| -| 86.7995 | 13312 | -395.822 | -315.359 | -325.237 | -+------------------+-------+-------------------+-------------+-------------+ -""" - -import os -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum -from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME -from ray.rllib.utils.test_utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) -from ray.tune.registry import get_trainable_cls, register_env - -parser = add_rllib_example_script_args( - default_iters=200, - default_timesteps=100000, - default_reward=-400.0, -) -# TODO (sven): This arg is currently ignored (hard-set to 2). -parser.add_argument("--num-policies", type=int, default=2) - - -if __name__ == "__main__": - args = parser.parse_args() - - # Register our environment with tune. - if args.num_agents > 1: - register_env( - "env", - lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}), - ) - else: - raise ValueError( - f"`num_agents` must be > 1, but is {args.num_agents}." - "Read the script docstring for more information." - ) - - assert args.checkpoint_freq > 0, ( - "This example requires at least one checkpoint to load the RLModule " - "weights for policy 0." - ) - - base_config = ( - get_trainable_cls(args.algo) - .get_default_config() - .environment("env") - .training( - train_batch_size_per_learner=512, - mini_batch_size_per_learner=64, - lambda_=0.1, - gamma=0.95, - lr=0.0003, - vf_clip_param=10.0, - ) - .rl_module( - model_config_dict={"fcnet_activation": "relu"}, - ) - ) - - # Add a simple multi-agent setup. - if args.num_agents > 0: - base_config.multi_agent( - policies={f"p{i}" for i in range(args.num_agents)}, - policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}", - ) - - # Augment the base config with further settings and train the agents. - results = run_rllib_example_script_experiment(base_config, args) - - # Create an env instance to get the observation and action spaces. - env = MultiAgentPendulum(config={"num_agents": args.num_agents}) - # Get the default module spec from the algorithm config. - module_spec = base_config.get_default_rl_module_spec() - module_spec.model_config_dict = base_config.model_config | { - "fcnet_activation": "relu", - } - module_spec.observation_space = env.envs[0].observation_space - module_spec.action_space = env.envs[0].action_space - # Create the module for each policy, but policy 0. - module_specs = {} - for i in range(1, args.num_agents or 1): - module_specs[f"p{i}"] = module_spec - - # Now swap in the RLModule weights for policy 0. 
- chkpt_path = results.get_best_result().checkpoint.path - p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0") - module_spec.load_state_path = p_0_module_state_path - module_specs["p0"] = module_spec - - # Create the MARL module. - marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) - # Define the MARL module in the base config. - base_config.rl_module(rl_module_spec=marl_module_spec) - # We need to re-register the environment when starting a new run. - register_env( - "env", - lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}), - ) - # Define stopping criteria. - stop = { - f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000, - TRAINING_ITERATION: 30, - } - - # Run the experiment again with the restored MARL module. - run_rllib_example_script_experiment(base_config, args, stop=stop) From 6f6dd8cdb7cf1af30e50a200776761e2585d8718 Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Wed, 22 May 2024 09:52:07 +0200 Subject: [PATCH 5/9] Changed checkpoint frequency to 20 as test was not passing due to cache issues. In addition added 'no_main' tag to test in BUILD b/c linter errored out. Signed-off-by: Simon Zehnder --- rllib/BUILD | 4 ++-- .../checkpoints/restore_1_of_n_agents_from_checkpoint.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 9453c380747a..a45ff5ba8611 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2735,10 +2735,10 @@ py_test( py_test( name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint", main = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py", - tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core"], + tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core", "no_main"], size = "large", srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"], - args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--checkpoint-freq=4", "--checkpoint-at-end", "--num-cpus=4"] + args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4"] ) py_test( diff --git a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py index 69f20801b878..4338791c71fa 100644 --- a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py +++ b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py @@ -14,15 +14,15 @@ How to run this script ---------------------- `python [script file name].py --enable-new-api-stack --num-agents=2 ---checkpoint-freq=4 --checkpoint-at-end` +--checkpoint-freq=20 --checkpoint-at-end` Control the number of agents and policies (RLModules) via --num-agents and --num-policies. -Control the number of checkpoints by setting --checkpoint-freq to a value > 0. +Control the number of checkpoints by setting `--checkpoint-freq` to a value > 0. Note that the checkpoint frequency is per iteration and this example needs at least a single checkpoint to load the RLModule weights for policy 0. -If --checkpoint-at-end is set, a checkpoint will be saved at the end of the +If `--checkpoint-at-end` is set, a checkpoint will be saved at the end of the experiment. 
For debugging, use the following additional command line options From ed95651c3f4ed226bfcff4542d7a4ed52f0e3ed4 Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Wed, 22 May 2024 11:59:07 +0200 Subject: [PATCH 6/9] Removed '--as-test' argument from example file. Signed-off-by: Simon Zehnder --- rllib/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/BUILD b/rllib/BUILD index a45ff5ba8611..9b2d67c3689d 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2738,7 +2738,7 @@ py_test( tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core", "no_main"], size = "large", srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"], - args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4"] + args = ["--enable-new-api-stack", "--num-agents=2", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4", "--algo=PPO"] ) py_test( From 9acc2fb406f37c2b76adf7f3fc966ceea47491ea Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Wed, 22 May 2024 12:18:38 +0200 Subject: [PATCH 7/9] Added 'no_main' tag to the old example in the BUILD file. Signed-off-by: Simon Zehnder --- rllib/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/BUILD b/rllib/BUILD index 9b2d67c3689d..d7a2b38111d3 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2141,7 +2141,7 @@ py_test( #@OldAPIStack py_test( name = "examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint", - tags = ["team:rllib", "exclusive", "examples"], + tags = ["team:rllib", "exclusive", "examples", "no_main"], size = "medium", srcs = ["examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint.py"], args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"] From 2527f6e8b51cd7ed0dbe1550ec72e6939a782f60 Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Wed, 22 May 2024 17:48:52 +0200 Subject: [PATCH 8/9] Added the file of the old stack test to 'main'. Signed-off-by: Simon Zehnder --- rllib/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/BUILD b/rllib/BUILD index d7a2b38111d3..a47fbed441db 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2141,6 +2141,7 @@ py_test( #@OldAPIStack py_test( name = "examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint", + main = "examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint.py", tags = ["team:rllib", "exclusive", "examples", "no_main"], size = "medium", srcs = ["examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint.py"], From a302bd1c99799f88b452306a98ae60c233d87065 Mon Sep 17 00:00:00 2001 From: Simon Zehnder Date: Fri, 24 May 2024 13:53:06 +0200 Subject: [PATCH 9/9] Removed old example from example folder and BUILD file. 
Signed-off-by: Simon Zehnder --- rllib/BUILD | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 8549e7640dd6..097ad26ca80c 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2157,16 +2157,6 @@ py_test( srcs = ["examples/checkpoints/onnx_torch.py"], ) -#@OldAPIStack -py_test( - name = "examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint", - main = "examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint.py", - tags = ["team:rllib", "exclusive", "examples", "no_main"], - size = "medium", - srcs = ["examples/checkpoints/old_stack_restore_1_of_n_agents_from_checkpoint.py"], - args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"] -) - # subdirectory: connectors/ # .................................... # Framestacking examples only run in smoke-test mode (a few iters only).
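The key mechanism of this series, restoring only policy 0's RLModule state while the other policies start fresh, is spread across several hunks above. Below is a condensed sketch of that step, not part of the patch itself: the helper function and its signature are invented for illustration, the sketch deep-copies the spec per policy (the example script reuses a single spec object), and it assumes the `<checkpoint>/learner/module_state/p0` layout that the new example script uses to build the load path for policy 0.

```python
# Condensed, illustrative sketch of the restore step introduced in
# rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py.
# The helper name and arguments are hypothetical; only the spec fields and
# the checkpoint path layout follow the example script.
import copy
import os

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec


def build_marl_spec_with_restored_p0(module_spec, chkpt_path, num_agents):
    # One copy of the single-agent spec per policy other than p0; these
    # modules start from freshly initialized weights.
    module_specs = {
        f"p{i}": copy.deepcopy(module_spec) for i in range(1, num_agents)
    }

    # For p0, point the spec at the saved module state inside the checkpoint
    # so its weights are loaded instead of freshly initialized.
    p0_spec = copy.deepcopy(module_spec)
    p0_spec.load_state_path = os.path.join(
        chkpt_path, "learner", "module_state", "p0"
    )
    module_specs["p0"] = p0_spec

    return MultiAgentRLModuleSpec(module_specs=module_specs)


# Usage as in the example script, after the first training run:
#   marl_spec = build_marl_spec_with_restored_p0(
#       module_spec, results.get_best_result().checkpoint.path, args.num_agents
#   )
#   base_config.rl_module(rl_module_spec=marl_spec)
```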