[RLlib] Cleanup examples folder 04: Curriculum and checkpoint-by-custom-criteria examples moved to new API stack. #44706
@@ -1,67 +1,90 @@
# TODO (sven): Move this example script into the new API stack.
"""Example extracting a checkpoint from n trials using one or more custom criteria.

import argparse
import os
This example:
- runs a simple CartPole experiment with three different learning rates (three tune
"trials"). During the experiment, for each trial, we create a checkpoint at each
iteration.
- at the end of the experiment, we compare the trials and pick the one that performed
best, based on the criterion: Lowest episode count per single iteration (for CartPole,
a low episode count means the episodes are very long and thus the reward is also very
high).
- from that best trial (with the lowest episode count), we then pick those checkpoints
that a) have the lowest policy loss (good) and b) have the highest value function loss
(bad).

import ray
from ray import air, tune
from ray.tune.registry import get_trainable_cls

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
How to run this script
----------------------
`python [script file name].py --enable-new-api-stack`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`


Results to expect
-----------------
In the console output, you can see the performance of the three different learning
rates used here:

+-----------------------------+------------+-----------------+--------+--------+
| Trial name | status | loc | lr | iter |
|-----------------------------+------------+-----------------+--------+--------+
| PPO_CartPole-v1_d7dbe_00000 | TERMINATED | 127.0.0.1:98487 | 0.01 | 17 |
| PPO_CartPole-v1_d7dbe_00001 | TERMINATED | 127.0.0.1:98488 | 0.001 | 8 |
| PPO_CartPole-v1_d7dbe_00002 | TERMINATED | 127.0.0.1:98489 | 0.0001 | 9 |
+-----------------------------+------------+-----------------+--------+--------+

+------------------+-------+----------+----------------------+----------------------+
| total time (s) | ts | reward | episode_reward_max | episode_reward_min |
|------------------+-------+----------+----------------------+----------------------+
| 28.1068 | 39797 | 151.11 | 500 | 12 |
| 13.304 | 18728 | 158.91 | 500 | 15 |
| 14.8848 | 21069 | 167.36 | 500 | 13 |
+------------------+-------+----------+----------------------+----------------------+

+--------------------+
| episode_len_mean |
|--------------------|
| 151.11 |
| 158.91 |
| 167.36 |
+--------------------+
"""
from ray import tune
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=150.0)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
from ray.tune.registry import get_trainable_cls

parser = add_rllib_example_script_args(
    default_reward=450.0, default_timesteps=100000, default_iters=200
)


if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
    # Force-set `args.checkpoint_freq` to 1.
    args.checkpoint_freq = 1

    # Simple generic config.
    config = (
        get_trainable_cls(args.run)
    base_config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment("CartPole-v1")
        # Run with tracing enabled for tf2.
        .framework(args.framework)
        # Run 3 trials.
        .training(
            lr=tune.grid_search([0.01, 0.001, 0.0001]), train_batch_size=2341
        )  # TEST
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .training(lr=tune.grid_search([0.01, 0.001, 0.0001]), train_batch_size=2341)
    )

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    # Run tune for some iterations and generate checkpoints.
    tuner = tune.Tuner(
        args.run,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop, checkpoint_config=air.CheckpointConfig(checkpoint_frequency=1)
        ),
    )
    results = tuner.fit()
    results = run_rllib_example_script_experiment(base_config, args)

    # Get the best of the 3 trials by using some metric.
    # NOTE: Choosing the min `episodes_this_iter` automatically picks the trial
@@ -79,28 +102,32 @@
    best_result = results.get_best_result(metric=metric, mode="min", scope="all")
    value_best_metric = best_result.metrics_dataframe[metric].min()
    print(
        "Best trial's lowest episode length (over all "
        "iterations): {}".format(value_best_metric)
        f"Best trial was the one with lr={best_result.metrics['config']['lr']}. "
        f"Reached lowest episode count ({value_best_metric}) in a single iteration."
    )

    # Confirm we picked the right trial.
    assert value_best_metric <= results.get_dataframe()[metric].min()

    # Get the best checkpoints from the trial, based on different metrics.
    # Checkpoint with the lowest policy loss value:
    if config._enable_new_api_stack:
    if args.enable_new_api_stack:
        policy_loss_key = "info/learner/default_policy/policy_loss"
    else:
        policy_loss_key = "info/learner/default_policy/learner_stats/policy_loss"
    ckpt = results.get_best_result(metric=policy_loss_key, mode="min").checkpoint
    print("Lowest pol-loss: {}".format(ckpt))
    best_result = results.get_best_result(metric=policy_loss_key, mode="min")
    ckpt = best_result.checkpoint
    lowest_policy_loss = best_result.metrics_dataframe[policy_loss_key].min()
Review comment: We could also ask here for the best checkpoint along the training path.

Review comment: Ah, cool, so … And if the last is not the best one, it's better to do: …

Review comment: This actually doesn't seem to work well with nested keys.
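To illustrate what this thread is discussing (a sketch only, not part of the PR's diff): besides `get_best_result(...).checkpoint`, a trial's `Result` also exposes `best_checkpoints`, a list of (checkpoint, metrics) pairs saved along the training path. Scanning that list manually also sidesteps the nested-key issue, because the raw metrics dict can be indexed level by level instead of via a "/"-joined key. The nesting shown below mirrors the `info/learner/default_policy/...` keys used in the script and is an assumption, as is the `best_result` variable carried over from above.

# Sketch (not in this PR): pick the best checkpoint along the best trial's
# training path instead of only the single overall-best result.
best_ckpt, best_loss = None, float("inf")
for checkpoint, metrics in best_result.best_checkpoints:
    # Walk the nested metrics dict level by level (assumed layout mirroring the
    # "info/learner/default_policy/..." keys above) instead of a "/"-joined key.
    loss = metrics["info"]["learner"]["default_policy"]["policy_loss"]
    if loss < best_loss:
        best_ckpt, best_loss = checkpoint, loss
print(f"Best checkpoint along training path: {best_ckpt} (policy_loss={best_loss})")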
    print(f"Checkpoint w/ lowest policy loss: {ckpt}")
    print(f"Lowest policy loss: {lowest_policy_loss}")

    # Checkpoint with the highest value-function loss:
    if config._enable_new_api_stack:
    if args.enable_new_api_stack:
        vf_loss_key = "info/learner/default_policy/vf_loss"
    else:
        vf_loss_key = "info/learner/default_policy/learner_stats/vf_loss"
    ckpt = results.get_best_result(metric=vf_loss_key, mode="max").checkpoint
    print("Highest vf-loss: {}".format(ckpt))

    ray.shutdown()
    best_result = results.get_best_result(metric=vf_loss_key, mode="max")
    ckpt = best_result.checkpoint
    highest_value_fn_loss = best_result.metrics_dataframe[vf_loss_key].max()
Review comment: Here as well.
    print(f"Checkpoint w/ highest value function loss: {ckpt}")
    print(f"Highest value function loss: {highest_value_fn_loss}")
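As a hedged follow-up sketch (not part of this PR): once a checkpoint has been selected by either criterion, it can be restored into a fresh algorithm via RLlib's `Algorithm.from_checkpoint`. The `ckpt` variable is assumed to be the one computed in the script above.

from ray.rllib.algorithms.algorithm import Algorithm

# Rebuild an Algorithm from the selected checkpoint (assumes `ckpt` from above),
# run one more training iteration on it, then release its resources.
algo = Algorithm.from_checkpoint(ckpt)
result = algo.train()
algo.stop()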
Review comment: Works also with tune, but --local-mode :)

Review comment: Absolutely! I'm always afraid we are going to get rid of Ray local mode at some point. Also, for any number of Learner workers > 0, local mode doesn't work (not sure why, actually).
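To make the two debugging routes in this exchange concrete (illustrative only; the script-name placeholder is taken from the docstring above, and nothing here is part of the PR):

import ray

# Route discussed above: classic Ray local mode. Everything runs in one process,
# so breakpoints inside RLlib code are hit directly. Per the comment above, this
# breaks as soon as more than zero Learner workers are configured.
ray.init(local_mode=True)

# Alternative from the example-script utilities (see docstring): skip Tune and
# keep sampling in the main process:
#   python [script file name].py --enable-new-api-stack --no-tune --num-env-runners=0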