[RLlib] Cleanup examples folder #14: Add example script for policy (RLModule) inference on new API stack. #45831
@@ -1,53 +1,107 @@
-"""
-Example showing how you can use your trained policy for inference
-(computing actions) in an environment.
+"""Example on how to compute actions in production on an already trained policy.
+
+This example uses the simplest setup possible: An RLModule (policy net) recovered
+from a checkpoint and a manual env-loop (CartPole-v1). No ConnectorV2s or EnvRunners
+are used in this example.
+
+This example shows ..
+    - .. how to use an already existing checkpoint to extract a single-agent RLModule
+    (our policy network) from it.
+    - .. how to set up this recovered policy net for action computations (with or
+    without using exploration).
+    - .. how to have the policy run through a very simple gymnasium based env-loop,
+    w/o using RLlib's ConnectorV2s or EnvRunners.
+
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --stop-reward=200.0`
+
+Use the `--explore-during-inference` option to switch on exploratory behavior
+during inference. Normally, you should not explore during inference, though,
+unless your environment has a stochastic optimal solution.
+Use the `--num-episodes-during-inference=[int]` option to set the number of
+episodes to run through during the inference phase using the restored RLModule.
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+Note that the shown GPU settings in this script also work in case you are not
+running via tune, but instead are using the `--no-tune` command line option.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+You can visualize experiment results in ~/ray_results using TensorBoard.

-Includes options for LSTM-based models (--use-lstm), attention-net models
-(--use-attention), and plain (non-recurrent) models.
+
+Results to expect
+-----------------
+
+For the training step - depending on your `--stop-reward` setting - you should see
+something similar to this:
+
+Number of trials: 1/1 (1 TERMINATED)
++-----------------------------+------------+-----------------+--------+
+| Trial name                  | status     | loc             |   iter |
+|                             |            |                 |        |
+|-----------------------------+------------+-----------------+--------+
+| PPO_CartPole-v1_6660c_00000 | TERMINATED | 127.0.0.1:43566 |      8 |
++-----------------------------+------------+-----------------+--------+
++------------------+------------------------+------------------------+
+|   total time (s) |   num_env_steps_sample |   num_env_steps_traine |
+|                  |             d_lifetime |             d_lifetime |
++------------------+------------------------+------------------------+
+|          21.0283 |                  32000 |                  32000 |
++------------------+------------------------+------------------------+
+
+Then, after restoring the RLModule for the inference phase, your output should
+look similar to:
+
+Training completed. Restoring new RLModule for action inference.
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Episode done: Total reward = 500.0
+Done performing action inference through 10 Episodes
 """
 import argparse
 import gymnasium as gym
+import numpy as np
 import os

-import ray
-from ray import air, tune
-from ray.air.constants import TRAINING_ITERATION
-from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.core import DEFAULT_MODULE_ID
+from ray.rllib.core.columns import Columns
+from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.numpy import convert_to_numpy, softmax
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
     NUM_ENV_STEPS_SAMPLED_LIFETIME,
 )
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
 from ray.tune.registry import get_trainable_cls

-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
-)
-parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument(
-    "--framework",
-    choices=["tf", "tf2", "torch"],
-    default="torch",
-    help="The DL framework specifier.",
-)
-parser.add_argument(
-    "--stop-iters",
-    type=int,
-    default=200,
-    help="Number of iterations to train before we do inference.",
-)
-parser.add_argument(
-    "--stop-timesteps",
-    type=int,
-    default=100000,
-    help="Number of timesteps to train before we do inference.",
-)
-parser.add_argument(
-    "--stop-reward",
-    type=float,
-    default=150.0,
-    help="Reward at which we stop training before we do inference.",
+torch, _ = try_import_torch()
+
+parser = add_rllib_example_script_args(default_reward=200.0)
+parser.set_defaults(
+    # Make sure that - by default - we produce checkpoints during training.
+    checkpoint_freq=1,
+    checkpoint_at_end=True,
+    # Use CartPole-v1 by default.
+    env="CartPole-v1",
 )
 parser.add_argument(
     "--explore-during-inference",
@@ -59,74 +113,70 @@
     "--num-episodes-during-inference",
     type=int,
     default=10,
-    help="Number of episodes to do inference over after training.",
+    help="Number of episodes to do inference over (after restoring from a checkpoint).",
 )


 if __name__ == "__main__":
     args = parser.parse_args()

-    ray.init(num_cpus=args.num_cpus or None)
+    assert (
+        args.enable_new_api_stack
+    ), "Must set --enable-new-api-stack when running this script!"

-    config = (
-        get_trainable_cls(args.run)
-        .get_default_config()
-        .environment("FrozenLake-v1")
-        # Run with tracing enabled for tf2?
-        .framework(args.framework)
-        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
-        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
-    )
-
-    stop = {
-        TRAINING_ITERATION: args.stop_iters,
-        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
-        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
-    }
+    base_config = get_trainable_cls(args.algo).get_default_config()

     print("Training policy until desired reward/timesteps/iterations. ...")
-    tuner = tune.Tuner(
-        args.run,
-        param_space=config.to_dict(),
-        run_config=air.RunConfig(
-            stop=stop,
-            verbose=2,
-            checkpoint_config=air.CheckpointConfig(
-                checkpoint_frequency=1, checkpoint_at_end=True
-            ),
-        ),
-    )
-    results = tuner.fit()
+    results = run_rllib_example_script_experiment(base_config, args)

-    print("Training completed. Restoring new Algorithm for action inference.")
+    print("Training completed. Restoring new RLModule for action inference.")
     # Get the last checkpoint from the above training run.
-    checkpoint = results.get_best_result().checkpoint
+    best_result = results.get_best_result(
+        metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max"
+    )
-    # Create new Algorithm and restore its state from the last checkpoint.
-    algo = Algorithm.from_checkpoint(checkpoint)
+    rl_module = RLModule.from_checkpoint(
+        os.path.join(
+            best_result.checkpoint.path,
+            "learner",
+            "module_state",
+            DEFAULT_MODULE_ID,
+        )
+    )

     # Create the env to do inference in.
-    env = gym.make("FrozenLake-v1")
+    env = gym.make(args.env)
     obs, info = env.reset()

     num_episodes = 0
-    episode_reward = 0.0
+    episode_return = 0.0

     while num_episodes < args.num_episodes_during_inference:
-        # Compute an action (`a`).
-        a = algo.compute_single_action(
-            observation=obs,
-            explore=args.explore_during_inference,
-            policy_id="default_policy",  # <- default value
-        )
+        # Compute an action using a B=1 observation "batch".
+        input_dict = {Columns.OBS: torch.from_numpy(obs).unsqueeze(0)}
+        # No exploration.
+        if not args.explore_during_inference:
+            rl_module_out = rl_module.forward_inference(input_dict)
+        # Using exploration.
+        else:
+            rl_module_out = rl_module.forward_exploration(input_dict)
+
+        # For discrete action spaces used here, normally, an RLModule "only"
+        # produces action logits, from which we then have to sample.
+        # However, you can also write custom RLModules that output actions
+        # directly, performing the sampling step already inside their
+        # `forward_...()` methods.
+        logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS])
+        # Perform the sampling step in numpy for simplicity.
+        action = np.random.choice(env.action_space.n, p=softmax(logits[0]))
         # Send the computed action `a` to the env.
-        obs, reward, done, truncated, _ = env.step(a)
-        episode_reward += reward
+        obs, reward, terminated, truncated, _ = env.step(action)
+        episode_return += reward
Review comment: This is a beautiful example to show how simple this runs now. Let us think about some ways to simplify it also for modules using connectors.

Reply: Awesome! Yeah, I agree. Doing these examples feels very fast and easy. I don't have to do much debugging at all to make them run right from the get-go. There is another PR that does something very similar, but with a connector (that handles the LSTM states).

         # Is the episode `done`? -> Reset.
-        if done:
-            print(f"Episode done: Total reward = {episode_reward}")
+        if terminated or truncated:
+            print(f"Episode done: Total reward = {episode_return}")
             obs, info = env.reset()
             num_episodes += 1
-            episode_reward = 0.0
-
-    algo.stop()
+            episode_return = 0.0

-    ray.shutdown()
+    print(f"Done performing action inference through {num_episodes} Episodes")
Review comment: Should we use here the input specs of the module to infer the keys? Here it’s simple, but in other scenarios it could help users figure out what to feed in.
Reply: I'm not 100% sure. The input specs don't tell us much about what exact data is required for the keys in the specs. I'm honestly thinking about removing the specs altogether at some point (long-term).
For example: if my RLModule - right now - says "I need `obs` and `prev_rewards`", then I still don't know, for example, how many of the previous rewards are required. This detailed information - crucial for building the batch - is not something my model would tell me; I will have to provide a proper ConnectorV2 logic along with it.
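A small sketch of the suggestion above, not something this PR adds: assuming the `input_specs_inference()` / `input_specs_exploration()` accessors that the new API stack's `RLModule` uses for its spec checks, one could print the expected top-level input keys of a restored module. As the reply notes, this only names the keys; it does not say how the batch data behind each key has to be built.

```python
from ray.rllib.core.rl_module.rl_module import RLModule


def print_expected_input_keys(rl_module: RLModule) -> None:
    # The specs list the expected top-level keys (e.g. the observation column),
    # but not how the data behind each key has to be constructed; that knowledge
    # would live in an accompanying ConnectorV2 pipeline.
    print("forward_inference() expects:", rl_module.input_specs_inference())
    print("forward_exploration() expects:", rl_module.input_specs_exploration())
```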