diff --git a/rllib/BUILD b/rllib/BUILD
index 5af156d68c77..097ad26ca80c 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -2157,15 +2157,6 @@ py_test(
     srcs = ["examples/checkpoints/onnx_torch.py"],
 )
 
-#@OldAPIStack
-py_test(
-    name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
-    tags = ["team:rllib", "exclusive", "examples"],
-    size = "medium",
-    srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
-    args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"]
-)
-
 # subdirectory: connectors/
 # ....................................
 # Framestacking examples only run in smoke-test mode (a few iters only).
@@ -2751,6 +2742,15 @@ py_test(
 #     args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"],
 # )
 
+py_test(
+    name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
+    main = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core", "no_main"],
+    size = "large",
+    srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
+    args = ["--enable-new-api-stack", "--num-agents=2", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4", "--algo=PPO"]
+)
+
 py_test(
     name = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned",
     main = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py",
diff --git a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
index 9c4dc3805613..4338791c71fa 100644
--- a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
+++ b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
@@ -1,140 +1,153 @@
-# TODO (sven): Move this example script into the new API stack.
-
-"""Simple example of how to restore only one of n agents from a trained
-multi-agent Algorithm using Ray tune.
-
-Control the number of agents and policies via --num-agents and --num-policies.
+"""An example script showing how to load module weights for 1 of n agents
+from a checkpoint.
+
+This example:
+    - Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies.
+    - Saves a checkpoint of the `MultiAgentRLModule` used every `--checkpoint-freq`
+    iterations.
+    - Stops the experiment after the agents reach a combined return of `-800`.
+    - Picks the best checkpoint by combined return and restores policy 0 from it.
+    - Runs a second experiment with the restored `RLModule` for policy 0 and
+    a fresh `RLModule` for the other policies.
+    - Stops the second experiment after the agents reach a combined return of `-800`.
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --num-agents=2
+--checkpoint-freq=20 --checkpoint-at-end`
+
+Control the number of agents and policies (RLModules) via --num-agents and
+--num-policies.
+
+Control the number of checkpoints by setting `--checkpoint-freq` to a value > 0.
+Note that the checkpoint frequency is per iteration and this example needs at
+least one checkpoint to load the RLModule weights for policy 0.
+If `--checkpoint-at-end` is set, a checkpoint will be saved at the end of the
+experiment.
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+Results to expect
+-----------------
+You should expect a return of about -400.0 to be reached eventually by a
+simple PPO policy (no tuning, just RLlib's default settings). In the
+second run of the experiment, the MARL module weights for policy 0 are
+restored from the checkpoint of the first run. The return for a single agent
+should again reach about -400.0, but training should take fewer iterations
+(around 30 instead of 190).
 """
 
-import argparse
-import gymnasium as gym
 import os
-import random
 
-import ray
-from ray import air, tune
 from ray.air.constants import TRAINING_ITERATION
-from ray.rllib.algorithms.algorithm import Algorithm
-from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.rllib.algorithms.ppo import PPOConfig
-from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
-from ray.rllib.policy.policy import Policy
-from ray.rllib.utils.framework import try_import_tf
-from ray.rllib.utils.metrics import (
-    ENV_RUNNER_RESULTS,
-    EPISODE_RETURN_MEAN,
-    NUM_ENV_STEPS_SAMPLED_LIFETIME,
+from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
+from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
+from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
 )
-from ray.rllib.utils.test_utils import check_learning_achieved
+from ray.tune.registry import get_trainable_cls, register_env
 
-tf1, tf, tfv = try_import_tf()
-
-parser = argparse.ArgumentParser()
-
-parser.add_argument("--num-agents", type=int, default=4)
-parser.add_argument("--num-policies", type=int, default=2)
-parser.add_argument("--pre-training-iters", type=int, default=5)
-parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument(
-    "--framework",
-    choices=["tf", "tf2", "torch"],
-    default="torch",
-    help="The DL framework specifier.",
-)
-parser.add_argument(
-    "--as-test",
-    action="store_true",
-    help="Whether this script should be run as a test: --stop-reward must "
-    "be achieved within --stop-timesteps AND --stop-iters.",
-)
-parser.add_argument(
-    "--stop-iters", type=int, default=200, help="Number of iterations to train."
-)
-parser.add_argument(
-    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
-)
-parser.add_argument(
-    "--stop-reward", type=float, default=150.0, help="Reward at which we stop training."
+parser = add_rllib_example_script_args(
+    default_iters=200,
+    default_timesteps=100000,
+    default_reward=-400.0,
 )
+# TODO (sven): This arg is currently ignored (hard-set to 2).
+parser.add_argument("--num-policies", type=int, default=2)
+
 
 if __name__ == "__main__":
     args = parser.parse_args()
 
-    ray.init(num_cpus=args.num_cpus or None)
-
-    # Get obs- and action Spaces.
-    single_env = gym.make("CartPole-v1")
-    obs_space = single_env.observation_space
-    act_space = single_env.action_space
-
-    # Setup PPO with an ensemble of `num_policies` different policies.
-    policies = {
-        f"policy_{i}": (None, obs_space, act_space, None)
-        for i in range(args.num_policies)
-    }
-    policy_ids = list(policies.keys())
-
-    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
-        pol_id = random.choice(policy_ids)
-        return pol_id
-
-    config = (
-        PPOConfig()
-        .environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
-        .framework(args.framework)
-        .training(num_sgd_iter=10)
-        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
-        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
+    # Register our environment with tune.
+    if args.num_agents > 1:
+        register_env(
+            "env",
+            lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
+        )
+    else:
+        raise ValueError(
+            f"`num_agents` must be > 1, but is {args.num_agents}. "
+            "Read the script docstring for more information."
+        )
+
+    assert args.checkpoint_freq > 0, (
+        "This example requires at least one checkpoint to load the RLModule "
+        "weights for policy 0."
     )
 
-    # Do some training and store the checkpoint.
-    results = tune.Tuner(
-        "PPO",
-        param_space=config.to_dict(),
-        run_config=air.RunConfig(
-            stop={TRAINING_ITERATION: args.pre_training_iters},
-            verbose=1,
-            checkpoint_config=air.CheckpointConfig(
-                checkpoint_frequency=1, checkpoint_at_end=True
-            ),
-        ),
-    ).fit()
-    print("Pre-training done.")
-
-    best_checkpoint = results.get_best_result().checkpoint
-    print(f".. best checkpoint was: {best_checkpoint}")
-
-    policy_0_checkpoint = os.path.join(
-        best_checkpoint.to_directory(), "policies/policy_0"
+    base_config = (
+        get_trainable_cls(args.algo)
+        .get_default_config()
+        .environment("env")
+        .training(
+            train_batch_size_per_learner=512,
+            mini_batch_size_per_learner=64,
+            lambda_=0.1,
+            gamma=0.95,
+            lr=0.0003,
+            vf_clip_param=10.0,
+        )
+        .rl_module(
+            model_config_dict={"fcnet_activation": "relu"},
+        )
     )
-    restored_policy_0 = Policy.from_checkpoint(policy_0_checkpoint)
-    restored_policy_0_weights = restored_policy_0.get_weights()
-    print("Starting new tune.Tuner().fit()")
-
-    # Start our actual experiment.
-    stop = {
-        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
-        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
-        TRAINING_ITERATION: args.stop_iters,
+    # Add a simple multi-agent setup.
+    if args.num_agents > 0:
+        base_config.multi_agent(
+            policies={f"p{i}" for i in range(args.num_agents)},
+            policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
+        )
+
+    # Augment the base config with further settings and train the agents.
+    results = run_rllib_example_script_experiment(base_config, args)
+
+    # Create an env instance to get the observation and action spaces.
+    env = MultiAgentPendulum(config={"num_agents": args.num_agents})
+    # Get the default module spec from the algorithm config.
+    module_spec = base_config.get_default_rl_module_spec()
+    module_spec.model_config_dict = base_config.model_config | {
+        "fcnet_activation": "relu",
     }
-
-    class RestoreWeightsCallback(DefaultCallbacks):
-        def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
-            algorithm.set_weights({"policy_0": restored_policy_0_weights})
-
-    # Make sure, the non-1st policies are not updated anymore.
-    config.policies_to_train = [pid for pid in policy_ids if pid != "policy_0"]
-    config.callbacks(RestoreWeightsCallback)
-
-    results = tune.run(
-        "PPO",
-        stop=stop,
-        config=config.to_dict(),
-        verbose=1,
+    module_spec.observation_space = env.envs[0].observation_space
+    module_spec.action_space = env.envs[0].action_space
+    # Create the module specs for all policies except policy 0.
+    module_specs = {}
+    for i in range(1, args.num_agents or 1):
+        module_specs[f"p{i}"] = module_spec
+
+    # Now swap in the RLModule weights for policy 0 from the best checkpoint.
+    chkpt_path = results.get_best_result().checkpoint.path
+    p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0")
+    module_spec.load_state_path = p_0_module_state_path
+    module_specs["p0"] = module_spec
+
+    # Create the MARL module.
+    marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
+    # Define the MARL module in the base config.
+    base_config.rl_module(rl_module_spec=marl_module_spec)
+    # We need to re-register the environment when starting a new run.
+    register_env(
+        "env",
+        lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
     )
+    # Define stopping criteria.
+    stop = {
+        # TODO (simon): Change to -800 once the metrics are fixed. Currently
+        # the combined return is not correctly computed.
+        f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400,
+        f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000,
+        TRAINING_ITERATION: 30,
+    }
 
-    if args.as_test:
-        check_learning_achieved(results, args.stop_reward)
-
-    ray.shutdown()
+    # Run the experiment again with the restored MARL module.
+    run_rllib_example_script_experiment(base_config, args, stop=stop)
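
Note on the restore pattern: for readers who want the essence of the second run without scanning the whole diff, the new script's restore step boils down to the sketch below. It is a minimal rearrangement of calls that already appear in the script above (`get_default_rl_module_spec()`, `load_state_path`, `MultiAgentRLModuleSpec`, `config.rl_module()`); the names `base_config`, `args`, and `chkpt_path` are assumed to be set up exactly as in the script, and the observation/action-space and `model_config_dict` assignments are omitted here for brevity.

    import os

    from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec

    # Assumed to exist, as in the script above:
    #   base_config - the AlgorithmConfig used for the second run
    #   args        - the parsed example-script CLI args
    #   chkpt_path  - results.get_best_result().checkpoint.path from the first run

    # Fresh RLModules for every policy except "p0".
    module_specs = {
        f"p{i}": base_config.get_default_rl_module_spec()
        for i in range(1, args.num_agents)
    }

    # "p0" loads its state from the first run's best checkpoint.
    p0_spec = base_config.get_default_rl_module_spec()
    p0_spec.load_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0")
    module_specs["p0"] = p0_spec

    # Hand the combined multi-agent spec to the config; the following call to
    # run_rllib_example_script_experiment() then starts policy 0 from the
    # restored weights and all other policies from scratch.
    base_config.rl_module(
        rl_module_spec=MultiAgentRLModuleSpec(module_specs=module_specs)
    )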