[RLlib]: Cleanup examples folder: Add example restoring 1 of n agents from a checkpoint. #45462

Merged
Changes from all commits (19 commits):
c748df8
Changed comment.
simonsays1980 May 10, 2024
6409007
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 13, 2024
d2f9030
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 14, 2024
a3416a8
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 15, 2024
8582ad9
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 16, 2024
b565f34
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 21, 2024
c940fc4
Added example to restore 1 of n agents from checkpoint using Pendulum…
simonsays1980 May 21, 2024
504eddd
Added example to BUILD file.
simonsays1980 May 21, 2024
92a6cd3
Modification due to @sven1977's review.
simonsays1980 May 21, 2024
c0eed1f
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 22, 2024
540ad66
Merge branch 'master' into example-restore-1-of-n-agents-from-chkpt
simonsays1980 May 22, 2024
6f6dd8c
Changed checkpoint frequency to 20 as test was not passing due to cac…
simonsays1980 May 22, 2024
ed95651
Removed '--as-test' argument from example file.
simonsays1980 May 22, 2024
9acc2fb
Added 'no_main' tag to the old example in the BUILD file.
simonsays1980 May 22, 2024
2527f6e
Added the file of the old stack test to 'main'.
simonsays1980 May 22, 2024
341cb95
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 22, 2024
b76807f
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 24, 2024
3abb6a8
Merge branch 'master' into example-restore-1-of-n-agents-from-chkpt
simonsays1980 May 24, 2024
a302bd1
Removed old example from example folder and BUILD file.
simonsays1980 May 24, 2024
18 changes: 9 additions & 9 deletions rllib/BUILD
@@ -2157,15 +2157,6 @@ py_test(
srcs = ["examples/checkpoints/onnx_torch.py"],
)

#@OldAPIStack
py_test(
name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"]
)

# subdirectory: connectors/
# ....................................
# Framestacking examples only run in smoke-test mode (a few iters only).
@@ -2751,6 +2742,15 @@ py_test(
# args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"],
# )

py_test(
name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
main = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py",
tags = ["team:rllib", "exclusive", "examples", "examples_use_all_core", "no_main"],
size = "large",
srcs = ["examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py"],
args = ["--enable-new-api-stack", "--num-agents=2", "--framework=torch", "--checkpoint-freq=20", "--checkpoint-at-end", "--num-cpus=4", "--algo=PPO"]
)

py_test(
name = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned",
main = "examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py",
257 changes: 135 additions & 122 deletions rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
@@ -1,140 +1,153 @@
# TODO (sven): Move this example script into the new API stack.

"""Simple example of how to restore only one of n agents from a trained
multi-agent Algorithm using Ray tune.

Control the number of agents and policies via --num-agents and --num-policies.
"""An example script showing how to load module weights for 1 of n agents
from checkpoint.

This example:
- Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies.
- Saves a checkpoint of the `MultiAgentRLModule` used every `--checkpoint-freq`
iterations.
- Stops the experiments after the agents reach a combined return of `-800`.
- Picks the best checkpoint by combined return and restores policy 0 from it.
- Runs a second experiment with the restored `RLModule` for policy 0 and
a fresh `RLModule` for the other policies.
- Stops the second experiment after the agents reach a combined return of `-800`.

How to run this script
----------------------
`python [script file name].py --enable-new-api-stack --num-agents=2
--checkpoint-freq=20 --checkpoint-at-end`

Control the number of agents and policies (RLModules) via --num-agents and
--num-policies.

Control the number of checkpoints by setting `--checkpoint-freq` to a value > 0.
Note that the checkpoint frequency is per iteration and this example needs at
least a single checkpoint to load the RLModule weights for policy 0.
If `--checkpoint-at-end` is set, a checkpoint will be saved at the end of the
experiment.

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`

Results to expect
-----------------
You should expect a reward of -400.0 eventually being achieved by a simple
single PPO policy (no tuning, just using RLlib's default settings). In the
second run of the experiment, the MARL module weights for policy 0 are
restored from the checkpoint of the first run. The reward for a single agent
should be -400.0 again, but the training time should be shorter (around 30
iterations instead of 190).
"""

import argparse
import gymnasium as gym
import os
import random

import ray
from ray import air, tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,
)
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune.registry import get_trainable_cls, register_env

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()

parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--pre-training-iters", type=int, default=5)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
"--framework",
choices=["tf", "tf2", "torch"],
default="torch",
help="The DL framework specifier.",
)
parser.add_argument(
"--as-test",
action="store_true",
help="Whether this script should be run as a test: --stop-reward must "
"be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
"--stop-iters", type=int, default=200, help="Number of iterations to train."
)
parser.add_argument(
"--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
"--stop-reward", type=float, default=150.0, help="Reward at which we stop training."
parser = add_rllib_example_script_args(
default_iters=200,
default_timesteps=100000,
default_reward=-400.0,
)
# TODO (sven): This arg is currently ignored (hard-set to 2).
parser.add_argument("--num-policies", type=int, default=2)


if __name__ == "__main__":
args = parser.parse_args()

ray.init(num_cpus=args.num_cpus or None)

# Get obs- and action Spaces.
single_env = gym.make("CartPole-v1")
obs_space = single_env.observation_space
act_space = single_env.action_space

# Setup PPO with an ensemble of `num_policies` different policies.
policies = {
f"policy_{i}": (None, obs_space, act_space, None)
for i in range(args.num_policies)
}
policy_ids = list(policies.keys())

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
pol_id = random.choice(policy_ids)
return pol_id

config = (
PPOConfig()
.environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
.framework(args.framework)
.training(num_sgd_iter=10)
.multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
# Register our environment with tune.
if args.num_agents > 1:
register_env(
"env",
lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
)
else:
raise ValueError(
f"`num_agents` must be > 1, but is {args.num_agents}."
"Read the script docstring for more information."
)

assert args.checkpoint_freq > 0, (
"This example requires at least one checkpoint to load the RLModule "
"weights for policy 0."
)

# Do some training and store the checkpoint.
results = tune.Tuner(
"PPO",
param_space=config.to_dict(),
run_config=air.RunConfig(
stop={TRAINING_ITERATION: args.pre_training_iters},
verbose=1,
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=1, checkpoint_at_end=True
),
),
).fit()
print("Pre-training done.")

best_checkpoint = results.get_best_result().checkpoint
print(f".. best checkpoint was: {best_checkpoint}")

policy_0_checkpoint = os.path.join(
best_checkpoint.to_directory(), "policies/policy_0"
base_config = (
get_trainable_cls(args.algo)
.get_default_config()
.environment("env")
.training(
train_batch_size_per_learner=512,
mini_batch_size_per_learner=64,
lambda_=0.1,
gamma=0.95,
lr=0.0003,
vf_clip_param=10.0,
)
.rl_module(
model_config_dict={"fcnet_activation": "relu"},
)
)
restored_policy_0 = Policy.from_checkpoint(policy_0_checkpoint)
restored_policy_0_weights = restored_policy_0.get_weights()
print("Starting new tune.Tuner().fit()")

# Start our actual experiment.
stop = {
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
TRAINING_ITERATION: args.stop_iters,
# Add a simple multi-agent setup.
if args.num_agents > 0:
base_config.multi_agent(
policies={f"p{i}" for i in range(args.num_agents)},
policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
)

# Augment the base config with further settings and train the agents.
results = run_rllib_example_script_experiment(base_config, args)

# Create an env instance to get the observation and action spaces.
env = MultiAgentPendulum(config={"num_agents": args.num_agents})
# Get the default module spec from the algorithm config.
module_spec = base_config.get_default_rl_module_spec()
module_spec.model_config_dict = base_config.model_config | {
"fcnet_activation": "relu",
}

class RestoreWeightsCallback(DefaultCallbacks):
def on_algorithm_init(self, *, algorithm: "Algorithm", **kwargs) -> None:
algorithm.set_weights({"policy_0": restored_policy_0_weights})

# Make sure, the non-1st policies are not updated anymore.
config.policies_to_train = [pid for pid in policy_ids if pid != "policy_0"]
config.callbacks(RestoreWeightsCallback)

results = tune.run(
"PPO",
stop=stop,
config=config.to_dict(),
verbose=1,
module_spec.observation_space = env.envs[0].observation_space
module_spec.action_space = env.envs[0].action_space
# Create the module for each policy, but policy 0.
module_specs = {}
for i in range(1, args.num_agents or 1):
module_specs[f"p{i}"] = module_spec

# Now swap in the RLModule weights for policy 0.
chkpt_path = results.get_best_result().checkpoint.path
p_0_module_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0")
module_spec.load_state_path = p_0_module_state_path
module_specs["p0"] = module_spec

# Create the MARL module.
marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
# Define the MARL module in the base config.
base_config.rl_module(rl_module_spec=marl_module_spec)
# We need to re-register the environment when starting a new run.
register_env(
"env",
lambda _: MultiAgentPendulum(config={"num_agents": args.num_agents}),
)
# Define stopping criteria.
stop = {
# TODO (simon): Change to -800 once the metrics are fixed. Currently
# the combined return is not correctly computed.
f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400,
f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000,
TRAINING_ITERATION: 30,
}

if args.as_test:
check_learning_achieved(results, args.stop_reward)

ray.shutdown()
# Run the experiment again with the restored MARL module.
run_rllib_example_script_experiment(base_config, args, stop=stop)
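
The heart of the new example is the restore step: only policy `p0` gets its RLModule weights from the first run's checkpoint, while the remaining policies start from fresh modules. Below is a minimal sketch of that pattern, assuming `base_config` and `results` exist as built in the script above; it uses only the calls that appear in the diff, except that a separate spec object is created for `p0` so that `load_state_path` applies to that policy alone:

import os

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec

# Assumed to exist as in the script above: `base_config` (the PPO config) and
# `results` (the result grid returned by the first training run).
num_agents = 2  # matches --num-agents=2
chkpt_path = results.get_best_result().checkpoint.path

# Fresh module specs for every policy except p0.
module_specs = {
    f"p{i}": base_config.get_default_rl_module_spec() for i in range(1, num_agents)
}

# A separate spec for p0 that points at the module state saved inside the
# checkpoint, so only p0 starts from the pre-trained weights.
p0_spec = base_config.get_default_rl_module_spec()
p0_spec.load_state_path = os.path.join(chkpt_path, "learner", "module_state", "p0")
module_specs["p0"] = p0_spec

# The second run then builds p0 from the stored state and the others from scratch.
base_config.rl_module(
    rl_module_spec=MultiAgentRLModuleSpec(module_specs=module_specs)
)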