diff --git a/rllib/BUILD b/rllib/BUILD index 65b65b7cda91..9e056f39f061 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2486,24 +2486,13 @@ py_test( # subdirectory: inference/ # .................................... -#@OldAPIStack -py_test( - name = "examples/inference/policy_inference_after_training_tf", - main = "examples/inference/policy_inference_after_training.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/inference/policy_inference_after_training.py"], - args = ["--stop-iters=3", "--framework=tf"] -) - -#@OldAPIStack py_test( - name = "examples/inference/policy_inference_after_training_torch", + name = "examples/inference/policy_inference_after_training", main = "examples/inference/policy_inference_after_training.py", tags = ["team:rllib", "exclusive", "examples"], size = "medium", srcs = ["examples/inference/policy_inference_after_training.py"], - args = ["--stop-iters=3", "--framework=torch"] + args = ["--enable-new-api-stack", "--stop-reward=100.0"] ) #@OldAPIStack diff --git a/rllib/examples/inference/policy_inference_after_training.py b/rllib/examples/inference/policy_inference_after_training.py index 91b85ecf48ce..0f61f4519cd7 100644 --- a/rllib/examples/inference/policy_inference_after_training.py +++ b/rllib/examples/inference/policy_inference_after_training.py @@ -1,53 +1,107 @@ -""" -Example showing how you can use your trained policy for inference -(computing actions) in an environment. +"""Example on how to compute actions in production on an already trained policy. + +This example uses the simplest setup possible: An RLModule (policy net) recovered +from a checkpoint and a manual env-loop (CartPole-v1). No ConnectorV2s or EnvRunners are +used in this example. + +This example shows .. + - .. how to use an already existing checkpoint to extract a single-agent RLModule + from (our policy network). + - .. how to setup this recovered policy net for action computations (with or without + using exploration). + - .. have the policy run through a very simple gymnasium based env-loop, w/o using + RLlib's ConnectorV2s or EnvRunners. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack --stop-reward=200.0` + +Use the `--explore-during-inference` option to switch on exploratory behavior +during inference. Normally, you should not explore during inference, though, +unless your environment has a stochastic optimal solution. +Use the `--num-episodes-during-inference=[int]` option to set the number of +episodes to run through during the inference phase using the restored RLModule. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +Note that the shown GPU settings in this script also work in case you are not +running via tune, but instead are using the `--no-tune` command line option. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + +You can visualize experiment results in ~/ray_results using TensorBoard. -Includes options for LSTM-based models (--use-lstm), attention-net models -(--use-attention), and plain (non-recurrent) models. + +Results to expect +----------------- + +For the training step - depending on your `--stop-reward` setting, you should see +something similar to this: + +Number of trials: 1/1 (1 TERMINATED) ++-----------------------------+------------+-----------------+--------+ +| Trial name | status | loc | iter | +| | | | | +|-----------------------------+------------+-----------------+--------+ +| PPO_CartPole-v1_6660c_00000 | TERMINATED | 127.0.0.1:43566 | 8 | ++-----------------------------+------------+-----------------+--------+ ++------------------+------------------------+------------------------+ +| total time (s) | num_env_steps_sample | num_env_steps_traine | +| | d_lifetime | d_lifetime | ++------------------+------------------------+------------------------+ +| 21.0283 | 32000 | 32000 | ++------------------+------------------------+------------------------+ + +Then, after restoring the RLModule for the inference phase, your output should +look similar to: + +Training completed. Restoring new RLModule for action inference. +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Episode done: Total reward = 500.0 +Done performing action inference through 10 Episodes """ -import argparse import gymnasium as gym +import numpy as np import os -import ray -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.core import DEFAULT_MODULE_ID +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy, softmax from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, ) from ray.tune.registry import get_trainable_cls -parser = argparse.ArgumentParser() -parser.add_argument( - "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use." -) -parser.add_argument("--num-cpus", type=int, default=0) -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", -) -parser.add_argument( - "--stop-iters", - type=int, - default=200, - help="Number of iterations to train before we do inference.", -) -parser.add_argument( - "--stop-timesteps", - type=int, - default=100000, - help="Number of timesteps to train before we do inference.", -) -parser.add_argument( - "--stop-reward", - type=float, - default=150.0, - help="Reward at which we stop training before we do inference.", +torch, _ = try_import_torch() + +parser = add_rllib_example_script_args(default_reward=200.0) +parser.set_defaults( + # Make sure that - by default - we produce checkpoints during training. + checkpoint_freq=1, + checkpoint_at_end=True, + # Use CartPole-v1 by default. + env="CartPole-v1", ) parser.add_argument( "--explore-during-inference", @@ -59,74 +113,70 @@ "--num-episodes-during-inference", type=int, default=10, - help="Number of episodes to do inference over after training.", + help="Number of episodes to do inference over (after restoring from a checkpoint).", ) + if __name__ == "__main__": args = parser.parse_args() - ray.init(num_cpus=args.num_cpus or None) + assert ( + args.enable_new_api_stack + ), "Must set --enable-new-api-stack when running this script!" - config = ( - get_trainable_cls(args.run) - .get_default_config() - .environment("FrozenLake-v1") - # Run with tracing enabled for tf2? - .framework(args.framework) - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - stop = { - TRAINING_ITERATION: args.stop_iters, - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, - } + base_config = get_trainable_cls(args.algo).get_default_config() print("Training policy until desired reward/timesteps/iterations. ...") - tuner = tune.Tuner( - args.run, - param_space=config.to_dict(), - run_config=air.RunConfig( - stop=stop, - verbose=2, - checkpoint_config=air.CheckpointConfig( - checkpoint_frequency=1, checkpoint_at_end=True - ), - ), - ) - results = tuner.fit() + results = run_rllib_example_script_experiment(base_config, args) - print("Training completed. Restoring new Algorithm for action inference.") + print("Training completed. Restoring new RLModule for action inference.") # Get the last checkpoint from the above training run. - checkpoint = results.get_best_result().checkpoint + best_result = results.get_best_result( + metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max" + ) # Create new Algorithm and restore its state from the last checkpoint. - algo = Algorithm.from_checkpoint(checkpoint) + rl_module = RLModule.from_checkpoint( + os.path.join( + best_result.checkpoint.path, + "learner", + "module_state", + DEFAULT_MODULE_ID, + ) + ) # Create the env to do inference in. - env = gym.make("FrozenLake-v1") + env = gym.make(args.env) obs, info = env.reset() num_episodes = 0 - episode_reward = 0.0 + episode_return = 0.0 while num_episodes < args.num_episodes_during_inference: - # Compute an action (`a`). - a = algo.compute_single_action( - observation=obs, - explore=args.explore_during_inference, - policy_id="default_policy", # <- default value - ) + # Compute an action using a B=1 observation "batch". + input_dict = {Columns.OBS: torch.from_numpy(obs).unsqueeze(0)} + # No exploration. + if not args.explore_during_inference: + rl_module_out = rl_module.forward_inference(input_dict) + # Using exploration. + else: + rl_module_out = rl_module.forward_exploration(input_dict) + + # For discrete action spaces used here, normally, an RLModule "only" + # produces action logits, from which we then have to sample. + # However, you can also write custom RLModules that output actions + # directly, performing the sampling step already inside their + # `forward_...()` methods. + logits = convert_to_numpy(rl_module_out[Columns.ACTION_DIST_INPUTS]) + # Perform the sampling step in numpy for simplicity. + action = np.random.choice(env.action_space.n, p=softmax(logits[0])) # Send the computed action `a` to the env. - obs, reward, done, truncated, _ = env.step(a) - episode_reward += reward + obs, reward, terminated, truncated, _ = env.step(action) + episode_return += reward # Is the episode `done`? -> Reset. - if done: - print(f"Episode done: Total reward = {episode_reward}") + if terminated or truncated: + print(f"Episode done: Total reward = {episode_return}") obs, info = env.reset() num_episodes += 1 - episode_reward = 0.0 - - algo.stop() + episode_return = 0.0 - ray.shutdown() + print(f"Done performing action inference through {num_episodes} Episodes")