
[RLlib] Cleanup examples folder 18: Add example script for offline RL (BC) training on single-agent, while evaluating w/ multi-agent setup. #46251

11 changes: 11 additions & 0 deletions rllib/BUILD
@@ -3021,6 +3021,17 @@ py_test(

# subdirectory: offline_rl/
# ....................................

# @HybridAPIStack
py_test(
    name = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent",
    main = "examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py",
    tags = ["team:rllib", "exclusive", "examples"],
    size = "large",
    srcs = ["examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py"],
    args = ["--as-test"]
)

#@OldAPIStack
# TODO (sven): Doesn't seem to learn at the moment. Uncomment once fixed.
# py_test(
5 changes: 2 additions & 3 deletions rllib/algorithms/algorithm.py
@@ -1121,9 +1121,8 @@ def _evaluate_with_custom_eval_function(self):
def _evaluate_on_local_env_runner(self, env_runner):
if hasattr(env_runner, "input_reader") and env_runner.input_reader is None:
raise ValueError(
"Cannot evaluate on a local worker (wether there is no evaluation "
"EnvRunnerGroup OR no remote evaluation workers) in the Algorithm or "
"w/o an environment on that local worker!\nTry one of the following:"
"Can't evaluate on a local worker if this local worker does not have "
"an environment!\nTry one of the following:"
"\n1) Set `evaluation_interval` > 0 to force creating a separate "
"evaluation EnvRunnerGroup.\n2) Set `create_env_on_driver=True` to "
"force the local (non-eval) EnvRunner to have an environment to "
35 changes: 27 additions & 8 deletions rllib/algorithms/bc/bc.py
@@ -9,7 +9,9 @@
from ray.rllib.utils.metrics import (
    ALL_MODULES,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_TRAINED,
    SAMPLE_TIMER,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
@@ -166,34 +168,51 @@ def training_step(self) -> ResultDict:
max_env_steps=self.config.train_batch_size,
)

train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
# TODO (sven): Use metrics API as soon as we moved to new API stack
Collaborator: Is this one not using the MetricsLogger yet? I use it in my overhaul of offline RL.

Contributor (author): Yes, you should, all good there. :) But the hybrid API stack still goes through the `summarize_episodes` utility inside algorithm.py, which operates on the old RolloutMetrics objects returned by the RolloutWorkers.

# (from currently hybrid stack).
# self.metrics.log_dict(
# {
# NUM_AGENT_STEPS_SAMPLED_LIFETIME: len(train_batch),
# NUM_ENV_STEPS_SAMPLED_LIFETIME: len(train_batch),
# },
# reduce="sum",
# )
self._counters[NUM_AGENT_STEPS_SAMPLED] += len(train_batch)
self._counters[NUM_ENV_STEPS_SAMPLED] += len(train_batch)

# Updating the policy.
train_results = self.learner_group.update_from_batch(batch=train_batch)
# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
# self.metrics.log_dict(
# {
# NUM_AGENT_STEPS_TRAINED_LIFETIME: len(train_batch),
# NUM_ENV_STEPS_TRAINED_LIFETIME: len(train_batch),
# },
# reduce="sum",
# )
self._counters[NUM_AGENT_STEPS_TRAINED] += len(train_batch)
self._counters[NUM_ENV_STEPS_TRAINED] += len(train_batch)

# Synchronize weights.
# As the results contain for each policy the loss and in addition the
# total loss over all policies is returned, this total loss has to be
# removed.
policies_to_update = set(train_results.keys()) - {ALL_MODULES}

global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}

# with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
if self.workers.num_remote_workers() > 0:
self.workers.sync_weights(
from_worker_or_learner_group=self.learner_group,
policies=policies_to_update,
global_vars=global_vars,
)
# Get weights from Learner to local worker.
else:
self.workers.local_worker().set_weights(
self.learner_group.get_weights()
)

# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
return train_results
6 changes: 6 additions & 0 deletions rllib/core/learner/learner.py
@@ -1323,6 +1323,12 @@ def _update_from_batch_or_episodes(
                },
                env_steps=sum(len(e) for e in episodes),
            )
        # Have to convert to MultiAgentBatch.
        elif isinstance(batch, SampleBatch):
            assert len(self.module) == 1
            batch = MultiAgentBatch(
                {next(iter(self.module.keys())): batch}, env_steps=len(batch)
            )

        # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs
        # found in this batch. If not, throw an error.
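For orientation, a minimal, self-contained sketch (not part of this diff) of what the new `elif` branch does: a plain single-agent SampleBatch is wrapped into a MultiAgentBatch under the module's single ModuleID (a hypothetical "default_policy" below), with `env_steps` taken from the batch length.

from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

# Hypothetical 3-timestep single-agent batch.
batch = SampleBatch({"obs": [[0.1], [0.2], [0.3]], "actions": [0, 1, 0]})

# Assume a single-module setup (this is what the assert above guards), keyed
# by the hypothetical ModuleID "default_policy".
module_keys = ["default_policy"]
assert len(module_keys) == 1
ma_batch = MultiAgentBatch({module_keys[0]: batch}, env_steps=len(batch))
print(ma_batch.env_steps())  # -> 3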
4 changes: 4 additions & 0 deletions rllib/core/rl_module/marl_module.py
@@ -129,6 +129,10 @@ def keys(self) -> KeysView[ModuleID]:
"""Returns a keys view over the module IDs in this MultiAgentRLModule."""
return self._rl_modules.keys()

def __len__(self) -> int:
"""Returns the number of RLModules within this MultiAgentRLModule."""
return len(self._rl_modules)

@override(RLModule)
def as_multi_agent(self) -> "MultiAgentRLModule":
"""Returns a multi-agent wrapper around this module.
4 changes: 2 additions & 2 deletions rllib/env/multi_agent_env.py
@@ -530,10 +530,10 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def step(self, action_dict):
obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}

# the environment is expecting action for at least one agent
# The environment is expecting an action for at least one agent.
if len(action_dict) == 0:
raise ValueError(
"The environment is expecting action for at least one agent."
"The environment is expecting an action for at least one agent."
)

for i, action in action_dict.items():
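A quick illustration of the guard above (a hedged sketch, assuming MultiAgentCartPole from the RLlib examples is built from this step() implementation): stepping with an empty action dict raises, while stepping only a subset of the agents is fine.

from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole

env = MultiAgentCartPole({"num_agents": 2})
obs, infos = env.reset()
try:
    env.step({})  # No actions at all -> raises the ValueError above.
except ValueError as e:
    print(e)
# Acting only for agent 0 is allowed; only that sub-env gets stepped.
obs, rew, terminated, truncated, infos = env.step({0: 0})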
8 changes: 6 additions & 2 deletions rllib/evaluation/metrics.py
@@ -173,8 +173,12 @@ def summarize_episodes(
episode_rewards.append(episode.episode_reward)
for k, v in episode.custom_metrics.items():
custom_metrics[k].append(v)
for (_, policy_id), reward in episode.agent_rewards.items():
if policy_id != DEFAULT_POLICY_ID:
is_multi_agent = (
len(episode.agent_rewards) > 1
or DEFAULT_POLICY_ID not in episode.agent_rewards
)
if is_multi_agent:
for (_, policy_id), reward in episode.agent_rewards.items():
policy_rewards[policy_id].append(reward)
for k, v in episode.hist_data.items():
hist_stats[k] += v
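For context, a hedged sketch of the data this loop operates on: in the old-stack episode metrics, `agent_rewards` maps (agent_id, policy_id) tuples to accumulated episode rewards (the numbers below are hypothetical, taken from the example's expected results).

# Purely single-agent episode (only the default policy):
single_agent_rewards = {(0, "default_policy"): 452.4}
# Multi-agent evaluation episode from this example (agent 0 -> "main",
# agent 1 -> "random"); these entries feed the per-policy means reported as
# policy_reward_mean/main and policy_reward_mean/random.
multi_agent_rewards = {(0, "main"): 452.4, (1, "random"): 28.3}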
20 changes: 11 additions & 9 deletions rllib/evaluation/rollout_worker.py
@@ -541,17 +541,19 @@ def wrap(env):
):
pol._update_model_view_requirements_from_init_state()

self.multiagent: bool = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent and self.env is not None:
if not isinstance(
if (
self.config.is_multi_agent()
and self.env is not None
and not isinstance(
self.env,
(BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle),
):
raise ValueError(
f"Have multiple policies {self.policy_map}, but the "
f"env {self.env} is not a subclass of BaseEnv, "
f"MultiAgentEnv, ActorHandle, or ExternalMultiAgentEnv!"
)
)
):
raise ValueError(
f"You are running a multi-agent setup, but the env {self.env} is not a "
f"subclass of BaseEnv, MultiAgentEnv, ActorHandle, or "
f"ExternalMultiAgentEnv!"
)

if self.worker_index == 0:
logger.info("Built filter map: {}".format(self.filters))
167 changes: 167 additions & 0 deletions rllib/examples/offline_rl/pretrain_bc_single_agent_evaluate_as_multi_agent.py
@@ -0,0 +1,167 @@
# @HybridAPIStack

"""Example showing how to train a (SA) BC RLModule while evaluating in a MA setup.

Here, SA=single-agent and MA=multi-agent.

Note that the BC Algorithm - by default - runs on the hybrid API stack, using RLModules,
but not EnvRunners or ConnectorV2s yet.
Collaborator: Probably worth adding here that this specifically does not use SingleAgentEpisode/MultiAgentEpisode?


This example:
- demonstrates how you can train a single-agent BC Policy (RLModule) from a JSON
file, which contains SampleBatch (expert or non-expert) data.
- shows how you can run evaluation in a multi-agent setup (for example vs one
or more heuristic policies), while training the BC Policy.


How to run this script
----------------------
`python [script file name].py --checkpoint-at-end`

For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.

For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`


Results to expect
-----------------
In the console output, you can see that the episode returns of the "main" policy on
the evaluation track keep increasing as BC more and more closely clones the behavior
found in our (expert) JSON file.

After 50-100 iterations, you should see the episode reward reach 450.0.
Note that the opponent (random) policy does not learn as it's a) not a trainable
RLModule and b) not being trained via the BCConfig. It's only used for evaluation
purposes here.

+---------------------+------------+-----------------+--------+--------+
| Trial name | status | loc | iter | ts |
|---------------------+------------+-----------------+--------+--------+
| BC_None_ee65e_00000 | TERMINATED | 127.0.0.1:35031 | 93 | 203754 |
+---------------------+------------+-----------------+--------+--------+
+----------------------+------------------------+
| eps. return (main) | eps. return (random) |
|----------------------+------------------------|
| 452.4 | 28.3 |
+----------------------+------------------------+
"""
import os
from pathlib import Path

import gymnasium as gym

from ray import tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.examples._old_api_stack.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EVALUATION_RESULTS,
    NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.train.constants import TIME_TOTAL_S
from ray.tune.registry import register_env

parser = add_rllib_example_script_args(
    default_reward=450.0,
    default_timesteps=300000,
)
parser.set_defaults(num_agents=2)


if __name__ == "__main__":
    args = parser.parse_args()

    register_env("multi_cart", lambda cfg: MultiAgentCartPole(cfg))
    dummy_env = gym.make("CartPole-v1")

    offline_file = str(
        Path(os.path.join("../../", "tests/data/cartpole/large.json")).resolve()
    )

    base_config = (
        BCConfig()
        .environment(
            observation_space=dummy_env.observation_space,
Collaborator: Can we give a quick note on why the user needs to provide the spaces in this case?
            action_space=dummy_env.action_space,
        )
        .offline_data(
            input_=offline_file,
        )
        .multi_agent(
            policies={"main"},
            policy_mapping_fn=lambda *a, **kw: "main",
        )
        .evaluation(
            evaluation_interval=1,
            evaluation_num_env_runners=0,
            evaluation_config=BCConfig.overrides(
                # Evaluate on an actual env -> switch input back to "sampler".
                input_="sampler",
                # Do not explore during evaluation, but act greedily.
                explore=False,
                # Use a multi-agent setup for evaluation.
                env="multi_cart",
                env_config={"num_agents": args.num_agents},
                policies={
                    "main": PolicySpec(),
                    "random": PolicySpec(policy_class=RandomPolicy),
                },
                # Only control agent 0 with the main (trained) policy.
                policy_mapping_fn=(
                    lambda aid, *a, **kw: "main" if aid == 0 else "random"
                ),
                # Note that we do NOT have to specify `policies_to_train` here,
                # b/c we are inside the evaluation config (no policy is trained
                # during evaluation). The fact that the BCConfig above is "only"
                # set up as single-agent makes it automatically train only the
                # policy found in the BCConfig's `policies` field (which is
                # "main").
                # policies_to_train=["main"],
            ),
        )
    )

    policy_eval_returns = (
        f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/policy_reward_mean/"
    )

    stop = {
        # Check for the "main" policy's episode return, not the combined one.
        # The combined one is the sum of the "main" policy's and the "random"
        # policy's returns.
        policy_eval_returns + "main": args.stop_reward,
        NUM_ENV_STEPS_TRAINED: args.stop_timesteps,
        TRAINING_ITERATION: args.stop_iters,
    }

    run_rllib_example_script_experiment(
        base_config,
        args,
        stop=stop,
        # We use a special progress reporter here to show the evaluation results
        # (of the "main" policy).
        # In the following dict, the keys are the (possibly nested) keys that can
        # be found in RLlib's (BC's) result dict, produced at every training
        # iteration, and the values are the column names you would like to see in
        # your console reports.
        # Note that for nested result dict keys, you need to use slashes "/" to
        # define the exact path.
        progress_reporter=tune.CLIReporter(
            metric_columns={
                TRAINING_ITERATION: "iter",
                TIME_TOTAL_S: "total time (s)",
                NUM_ENV_STEPS_TRAINED: "ts",
                policy_eval_returns + "main": "eps. return (main)",
                policy_eval_returns + "random": "eps. return (random)",
            }
        ),
    )