[RLlib] Cleanup examples folder 18: Add example script for offline RL (BC) training on single-agent, while evaluating w/ multi-agent setup. (#46251)
sven1977 authored Jun 26, 2024
1 parent e9109e6 commit b257b49
Showing 13 changed files with 267 additions and 40 deletions.
12 changes: 12 additions & 0 deletions rllib/BUILD
@@ -3031,6 +3031,18 @@ py_test(

# subdirectory: offline_rl/
# ....................................

# @HybridAPIStack
py_test(
name = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent",
main = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "large",
srcs = ["examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py"],
data = ["tests/data/cartpole/large.json"],
args = ["--as-test"]
)

#@OldAPIStack
# TODO (sven): Doesn't seem to learn at the moment. Uncomment once fixed.
# py_test(
5 changes: 2 additions & 3 deletions rllib/algorithms/algorithm.py
@@ -1112,9 +1112,8 @@ def _evaluate_with_custom_eval_function(self) -> Tuple[ResultDict, int, int]:
def _evaluate_on_local_env_runner(self, env_runner):
if hasattr(env_runner, "input_reader") and env_runner.input_reader is None:
raise ValueError(
"Cannot evaluate on a local worker (wether there is no evaluation "
"EnvRunnerGroup OR no remote evaluation workers) in the Algorithm or "
"w/o an environment on that local worker!\nTry one of the following:"
"Can't evaluate on a local worker if this local worker does not have "
"an environment!\nTry one of the following:"
"\n1) Set `evaluation_interval` > 0 to force creating a separate "
"evaluation EnvRunnerGroup.\n2) Set `create_env_on_driver=True` to "
"force the local (non-eval) EnvRunner to have an environment to "
35 changes: 27 additions & 8 deletions rllib/algorithms/bc/bc.py
@@ -9,7 +9,9 @@
from ray.rllib.utils.metrics import (
ALL_MODULES,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
SAMPLE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
)
@@ -166,34 +168,51 @@ def training_step(self) -> ResultDict:
max_env_steps=self.config.train_batch_size,
)

train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
# self.metrics.log_dict(
# {
# NUM_AGENT_STEPS_SAMPLED_LIFETIME: len(train_batch),
# NUM_ENV_STEPS_SAMPLED_LIFETIME: len(train_batch),
# },
# reduce="sum",
# )
self._counters[NUM_AGENT_STEPS_SAMPLED] += len(train_batch)
self._counters[NUM_ENV_STEPS_SAMPLED] += len(train_batch)

# Updating the policy.
train_results = self.learner_group.update_from_batch(batch=train_batch)
# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
# self.metrics.log_dict(
# {
# NUM_AGENT_STEPS_TRAINED_LIFETIME: len(train_batch),
# NUM_ENV_STEPS_TRAINED_LIFETIME: len(train_batch),
# },
# reduce="sum",
# )
self._counters[NUM_AGENT_STEPS_TRAINED] += len(train_batch)
self._counters[NUM_ENV_STEPS_TRAINED] += len(train_batch)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
# total loss over all policies is returned, this total loss has to be
# removed.
policies_to_update = set(train_results.keys()) - {ALL_MODULES}

global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}

# with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
if self.workers.num_remote_workers() > 0:
self.workers.sync_weights(
from_worker_or_learner_group=self.learner_group,
policies=policies_to_update,
global_vars=global_vars,
)
# Get weights from Learner to local worker.
else:
self.workers.local_worker().set_weights(
self.learner_group.get_weights()
)

# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
return train_results
2 changes: 1 addition & 1 deletion rllib/algorithms/tests/test_algorithm.py
@@ -325,7 +325,7 @@ def test_evaluation_wo_evaluation_env_runner_group(self):
algo_wo_env_on_local_worker = config.build()
self.assertRaisesRegex(
ValueError,
"Cannot evaluate on a local worker",
"Can't evaluate on a local worker",
algo_wo_env_on_local_worker.evaluate,
)
algo_wo_env_on_local_worker.stop()
6 changes: 6 additions & 0 deletions rllib/core/learner/learner.py
@@ -1177,6 +1177,12 @@ def _update_from_batch_or_episodes(
},
env_steps=sum(len(e) for e in episodes),
)
# Have to convert to MultiAgentBatch.
elif isinstance(batch, SampleBatch):
assert len(self.module) == 1
batch = MultiAgentBatch(
{next(iter(self.module.keys())): batch}, env_steps=len(batch)
)

# Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
# found in this batch. If not, throw an error.
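For readers unfamiliar with these batch types, below is a minimal standalone sketch (not part of this commit) of what the new `elif` branch does: a plain single-agent `SampleBatch` is wrapped into a `MultiAgentBatch` keyed by the single module's ID, so the rest of the Learner code can treat every update as multi-agent. The module ID "default_policy" and the batch contents are assumptions for illustration only.

import numpy as np
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

# A single-agent batch with 4 timesteps (contents are illustrative only).
single_agent_batch = SampleBatch(
    {"obs": np.zeros((4, 4)), "actions": np.array([0, 1, 0, 1])}
)
# Wrap it under an assumed module ID, analogous to the `elif` branch above.
ma_batch = MultiAgentBatch(
    {"default_policy": single_agent_batch}, env_steps=len(single_agent_batch)
)
assert ma_batch.env_steps() == 4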
4 changes: 4 additions & 0 deletions rllib/core/rl_module/marl_module.py
@@ -129,6 +129,10 @@ def keys(self) -> KeysView[ModuleID]:
"""Returns a keys view over the module IDs in this MultiAgentRLModule."""
return self._rl_modules.keys()

def __len__(self) -> int:
"""Returns the number of RLModules within this MultiAgentRLModule."""
return len(self._rl_modules)

@override(RLModule)
def as_multi_agent(self) -> "MultiAgentRLModule":
"""Returns a multi-agent wrapper around this module.
4 changes: 2 additions & 2 deletions rllib/env/multi_agent_env.py
@@ -530,10 +530,10 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def step(self, action_dict):
obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}

- # the environment is expecting action for at least one agent
+ # The environment is expecting an action for at least one agent.
if len(action_dict) == 0:
raise ValueError(
"The environment is expecting action for at least one agent."
"The environment is expecting an action for at least one agent."
)

for i, action in action_dict.items():
8 changes: 6 additions & 2 deletions rllib/evaluation/metrics.py
@@ -173,8 +173,12 @@ def summarize_episodes(
episode_rewards.append(episode.episode_reward)
for k, v in episode.custom_metrics.items():
custom_metrics[k].append(v)
- for (_, policy_id), reward in episode.agent_rewards.items():
- if policy_id != DEFAULT_POLICY_ID:
+ is_multi_agent = (
+ len(episode.agent_rewards) > 1
+ or DEFAULT_POLICY_ID not in episode.agent_rewards
+ )
+ if is_multi_agent:
+ for (_, policy_id), reward in episode.agent_rewards.items():
policy_rewards[policy_id].append(reward)
for k, v in episode.hist_data.items():
hist_stats[k] += v
20 changes: 11 additions & 9 deletions rllib/evaluation/rollout_worker.py
@@ -541,17 +541,19 @@ def wrap(env):
):
pol._update_model_view_requirements_from_init_state()

- self.multiagent: bool = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
- if self.multiagent and self.env is not None:
- if not isinstance(
+ if (
+ self.config.is_multi_agent()
+ and self.env is not None
+ and not isinstance(
self.env,
(BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle),
- ):
- raise ValueError(
- f"Have multiple policies {self.policy_map}, but the "
- f"env {self.env} is not a subclass of BaseEnv, "
- f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!"
- )
+ )
+ ):
+ raise ValueError(
+ f"You are running a multi-agent setup, but the env {self.env} is not a "
+ f"subclass of BaseEnv, MultiAgentEnv, ActorHandle, or "
+ f"ExternalMultiAgentEnv!"
+ )

if self.worker_index == 0:
logger.info("Built filter map: {}".format(self.filters))
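The check now keys off `AlgorithmConfig.is_multi_agent()` instead of the removed `self.multiagent` attribute. As a rough standalone sketch (not part of this commit; the policy IDs are made up), the flag is driven purely by the config's `.multi_agent()` settings:

from ray.rllib.algorithms.bc import BCConfig

config = BCConfig().multi_agent(
    policies={"main", "random"},
    policy_mapping_fn=lambda agent_id, *args, **kwargs: (
        "main" if agent_id == 0 else "random"
    ),
)
# With policies configured, the worker-side check above treats this as a
# multi-agent setup and will require a MultiAgentEnv (or BaseEnv, ActorHandle,
# ExternalMultiAgentEnv) once an env is attached.
assert config.is_multi_agent()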
171 changes: 171 additions & 0 deletions rllib/examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py (new file)
@@ -0,0 +1,171 @@
# @HybridAPIStack

"""Example showing how to train a (SA) BC RLModule while evaluating in a MA setup.
Here, SA=single-agent and MA=multi-agent.
Note that the BC Algorithm - by default - runs on the hybrid API stack, using RLModules,
but not `ConnectorV2` and `SingleAgentEpisode` yet.
This example:
- demonstrates how you can train a single-agent BC Policy (RLModule) from a JSON
file, which contains SampleBatch (expert or non-expert) data.
- shows how you can run evaluation in a multi-agent setup (for example vs one
or more heuristic policies), while training the BC Policy.
How to run this script
----------------------
`python [script file name].py --checkpoint-at-end`
For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.
For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
Results to expect
-----------------
In the console output, you can see that the episode returns of the "main" policy on
the evaluation track keep increasing as BC manages to more and more clone the behavior
found in our (expert) JSON file.
After 50-100 iterations, you should see the episode reward reach 450.0.
Note that the opponent (random) policy does not learn as it's a) not a trainable
RLModule and b) not being trained via the BCConfig. It's only used for evaluation
purposes here.
+---------------------+------------+-----------------+--------+--------+
| Trial name | status | loc | iter | ts |
|---------------------+------------+-----------------+--------+--------+
| BC_None_ee65e_00000 | TERMINATED | 127.0.0.1:35031 | 93 | 203754 |
+---------------------+------------+-----------------+--------+--------+
+----------------------+------------------------+
| eps. return (main) | eps. return (random) |
|----------------------+------------------------|
| 452.4 | 28.3 |
+----------------------+------------------------+
"""
import os
from pathlib import Path

import gymnasium as gym

from ray import tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.examples._old_api_stack.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EVALUATION_RESULTS,
NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.test_utils import (
add_rllib_example_script_args,
run_rllib_example_script_experiment,
)
from ray.train.constants import TIME_TOTAL_S
from ray.tune.registry import register_env

parser = add_rllib_example_script_args(
default_reward=450.0,
default_timesteps=300000,
)
parser.set_defaults(num_agents=2)


if __name__ == "__main__":
args = parser.parse_args()

register_env("multi_cart", lambda cfg: MultiAgentCartPole(cfg))
dummy_env = gym.make("CartPole-v1")

rllib_dir = Path(__file__).parent.parent.parent
print(f"rllib dir={rllib_dir}")
offline_file = os.path.join(rllib_dir, "tests/data/cartpole/large.json")

base_config = (
BCConfig()
# For offline RL, we do not specify an env here (b/c we don't want any env
# instances created on the EnvRunners). Instead, we'll provide observation-
# and action-spaces here for the RLModule to know its input- and output types.
.environment(
observation_space=dummy_env.observation_space,
action_space=dummy_env.action_space,
)
.offline_data(
input_=offline_file,
)
.multi_agent(
policies={"main"},
policy_mapping_fn=lambda *a, **kw: "main",
)
.evaluation(
evaluation_interval=1,
evaluation_num_env_runners=0,
evaluation_config=BCConfig.overrides(
# Evaluate on an actual env -> switch input back to "sampler".
input_="sampler",
# Do not explore during evaluation, but act greedily.
explore=False,
# Use a multi-agent setup for evaluation.
env="multi_cart",
env_config={"num_agents": args.num_agents},
policies={
"main": PolicySpec(),
"random": PolicySpec(policy_class=RandomPolicy),
},
# Only control agent 0 with the main (trained) policy.
policy_mapping_fn=(
lambda aid, *a, **kw: "main" if aid == 0 else "random"
),
# Note that we do NOT have to specify the `policies_to_train` here,
# b/c we are inside the evaluation config (no policy is trained during
# evaluation). The fact that the BCConfig above is "only" setup
# as single-agent makes it automatically only train the policy found in
# the BCConfig's `policies` field (which is "main").
# policies_to_train=["main"],
),
)
)

policy_eval_returns = (
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/policy_reward_mean/"
)

stop = {
# Check for the "main" policy's episode return, not the combined one.
# The combined one is the sum of the "main" policy + the "random" one.
policy_eval_returns + "main": args.stop_reward,
NUM_ENV_STEPS_TRAINED: args.stop_timesteps,
TRAINING_ITERATION: args.stop_iters,
}

run_rllib_example_script_experiment(
base_config,
args,
stop=stop,
success_metric={policy_eval_returns + "main": args.stop_reward},
# We use a special progress reporter here to show the evaluation results (of the
# "main" policy).
# In the following dict, the keys are the (possibly nested) keys that can be
# found in RLlib's (BC's) result dict, produced at every training iteration, and
# the values are the column names you would like to see in your console reports.
# Note that for nested result dict keys, you need to use slashes "/" to define
# the exact path.
progress_reporter=tune.CLIReporter(
metric_columns={
TRAINING_ITERATION: "iter",
TIME_TOTAL_S: "total time (s)",
NUM_ENV_STEPS_TRAINED: "ts",
policy_eval_returns + "main": "eps. return (main)",
policy_eval_returns + "random": "eps. return (random)",
}
),
)
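As a quick, standalone sanity check (not part of the file above) of the evaluation-time mapping used in `evaluation_config`: agent 0 is controlled by the trained "main" policy, every other agent by the "random" one.

# Mirrors the mapping lambda from the evaluation config above.
policy_mapping_fn = lambda aid, *args, **kwargs: "main" if aid == 0 else "random"
assert policy_mapping_fn(0) == "main"
assert policy_mapping_fn(1) == "random"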