[RLlib] Cleanup examples folder 18: Add example script for offline RL (BC) training on single-agent, while evaluating w/ multi-agent setup. #46251
@@ -0,0 +1,167 @@
# @HybridAPIStack

"""Example showing how to train a (SA) BC RLModule while evaluating in a MA setup.

Here, SA=single-agent and MA=multi-agent.

Note that the BC Algorithm - by default - runs on the hybrid API stack, using
RLModules, but not EnvRunners or ConnectorV2s yet.
Review comment: Probably worth adding here that it specifically does not use
[EnvRunners or ConnectorV2s].

This example:
  - demonstrates how you can train a single-agent BC Policy (RLModule) from a JSON
    file, which contains SampleBatch (expert or non-expert) data.
  - shows how you can run evaluation in a multi-agent setup (for example vs one
    or more heuristic policies), while training the BC Policy.


How to run this script
----------------------
`python [script file name].py --checkpoint-at-end`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`


Results to expect
-----------------
In the console output, you can see that the episode returns of the "main" policy on
the evaluation track keep increasing as BC clones the behavior found in our (expert)
JSON file more and more closely.

After 50-100 iterations, you should see the episode return reach 450.0.
Note that the opponent (random) policy does not learn as it's a) not a trainable
RLModule and b) not being trained via the BCConfig. It's only used for evaluation
purposes here.

+---------------------+------------+-----------------+--------+--------+
| Trial name          | status     | loc             |   iter |     ts |
|---------------------+------------+-----------------+--------+--------+
| BC_None_ee65e_00000 | TERMINATED | 127.0.0.1:35031 |     93 | 203754 |
+---------------------+------------+-----------------+--------+--------+
+----------------------+------------------------+
|   eps. return (main) |   eps. return (random) |
|----------------------+------------------------|
|                452.4 |                   28.3 |
+----------------------+------------------------+
"""
import os
from pathlib import Path

import gymnasium as gym

from ray import tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.examples._old_api_stack.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EVALUATION_RESULTS,
    NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.train.constants import TIME_TOTAL_S
from ray.tune.registry import register_env

parser = add_rllib_example_script_args(
    default_reward=450.0,
    default_timesteps=300000,
)
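# The script defaults to 2 agents for the (multi-agent) evaluation env; agent 0 is
# controlled by the trained "main" policy, all other agents by the "random" policy
# (see the evaluation config below).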
parser.set_defaults(num_agents=2)


if __name__ == "__main__":
    args = parser.parse_args()

    register_env("multi_cart", lambda cfg: MultiAgentCartPole(cfg))
    dummy_env = gym.make("CartPole-v1")
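    # The dummy (single-agent) env is only created to read out its observation-
    # and action spaces below; BC never steps through an env during (offline)
    # training.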

    offline_file = str(
        Path(os.path.join("../../", "tests/data/cartpole/large.json")).resolve()
    )

    base_config = (
        BCConfig()
        .environment(
            observation_space=dummy_env.observation_space,
Review comment: Can we give a quick note on why, in this case, the user needs to
provide the spaces?
            action_space=dummy_env.action_space,
        )
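        # The spaces have to be provided explicitly here because this (training)
        # config does not define any env to infer them from: BC only reads the
        # offline JSON file during training. An env is only attached in the
        # evaluation config further below.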
        .offline_data(
            input_=offline_file,
        )
        .multi_agent(
            policies={"main"},
            policy_mapping_fn=lambda *a, **kw: "main",
        )
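        # During (offline) training, there is only the single "main" policy; the
        # mapping fn above assigns all batches read from the JSON file to it.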
        .evaluation(
            evaluation_interval=1,
            evaluation_num_env_runners=0,
            evaluation_config=BCConfig.overrides(
                # Evaluate on an actual env -> switch input back to "sampler".
                input_="sampler",
                # Do not explore during evaluation, but act greedily.
                explore=False,
                # Use a multi-agent setup for evaluation.
                env="multi_cart",
                env_config={"num_agents": args.num_agents},
                policies={
                    "main": PolicySpec(),
                    "random": PolicySpec(policy_class=RandomPolicy),
                },
                # Only control agent 0 with the main (trained) policy.
                policy_mapping_fn=(
                    lambda aid, *a, **kw: "main" if aid == 0 else "random"
                ),
                # Note that we do NOT have to specify `policies_to_train` here,
                # b/c we are inside the evaluation config (no policy is trained
                # during evaluation). The fact that the BCConfig above is "only"
                # set up as single-agent makes it automatically train only the
                # policy found in the BCConfig's `policies` field (which is
                # "main").
                # policies_to_train=["main"],
            ),
        )
    )

    policy_eval_returns = (
        f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/policy_reward_mean/"
    )
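    # The prefix above resolves to "evaluation/env_runners/policy_reward_mean/";
    # appending a policy ID ("main" or "random") yields the exact nested key in
    # RLlib's result dict, as used in `stop` and in the CLIReporter below.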

    stop = {
        # Check for the "main" policy's episode return, not the combined one.
        # The combined one is the sum of the "main" policy + the "random" one.
        policy_eval_returns + "main": args.stop_reward,
        NUM_ENV_STEPS_TRAINED: args.stop_timesteps,
        TRAINING_ITERATION: args.stop_iters,
    }

    run_rllib_example_script_experiment(
        base_config,
        args,
        stop=stop,
        # We use a special progress reporter here to show the evaluation results
        # (of the "main" policy).
        # In the following dict, the keys are the (possibly nested) keys that can
        # be found in RLlib's (BC's) result dict, produced at every training
        # iteration, and the values are the column names you would like to see in
        # your console reports.
        # Note that for nested result dict keys, you need to use slashes "/" to
        # define the exact path.
        progress_reporter=tune.CLIReporter(
            metric_columns={
                TRAINING_ITERATION: "iter",
                TIME_TOTAL_S: "total time (s)",
                NUM_ENV_STEPS_TRAINED: "ts",
                policy_eval_returns + "main": "eps. return (main)",
                policy_eval_returns + "random": "eps. return (random)",
            }
        ),
    )

Review comment: Is this one not using the MetricsLogger yet? I use it in my overhaul
of offline RL.

Reply: Yes, you should. All good there. :) But the hybrid API stack still goes
through the `summarize_episodes` utility inside `algorithm.py`, which operates on
the old `RolloutMetrics` objects returned by RolloutWorkers.
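
For reference, here is a minimal sketch (not part of this PR; it only assumes the
nested result layout that the script's stop criterion and CLIReporter columns
already rely on) of how the per-policy evaluation returns could be read out of a
single training result dict; the helper name is made up for illustration:

from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, EVALUATION_RESULTS


def eval_policy_returns(result: dict) -> dict:
    """Returns the per-policy mean episode returns from one training result.

    Assumes the nested layout "evaluation" -> "env_runners" ->
    "policy_reward_mean" -> {policy_id: mean_return}, i.e. the same keys this
    example uses for its stop criterion and progress reporter.
    """
    eval_results = result.get(EVALUATION_RESULTS, {}).get(ENV_RUNNER_RESULTS, {})
    return eval_results.get("policy_reward_mean", {})


# Usage sketch: after `result = algo.train()`, calling
# `eval_policy_returns(result)` would yield something like
# {"main": 452.4, "random": 28.3}.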