[RLlib] Prioritized multi-agent episode replay buffer. (#45576)

simonsays1980 authored Jun 14, 2024
1 parent d844d63 commit a7aa5e4
Showing 17 changed files with 2,148 additions and 716 deletions.
61 changes: 38 additions & 23 deletions rllib/BUILD
@@ -303,6 +303,15 @@ py_test(
args = ["--as-test", "--enable-new-api-stack"]
)

py_test(
name = "learning_tests_multi_agent_cartpole_dqn",
main = "tuned_examples/dqn/multi_agent_cartpole_dqn.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_cartpole", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/dqn/multi_agent_cartpole_dqn.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-cpus=4"]
)

#@OldAPIStack
py_test(
name = "learning_tests_cartpole_dqn_softq_old_api_stack",
@@ -452,6 +461,24 @@ py_test(
args = ["--as-test", "--enable-new-api-stack"]
)

py_test(
name = "learning_tests_multi_agent_pendulum_sac",
main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous"],
size = "large",
srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-cpus=4"]
)

py_test(
name = "learning_tests_multi_agent_pendulum_sac_multi_gpu",
main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous", "multi_gpu"],
size = "large",
srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
args = ["--enable-new-api-stack", "--num-agents=2", "--num-gpus=2"]
)

#@OldAPIStack
py_test(
name = "learning_tests_cartpole_sac_old_api_stack",
@@ -463,25 +490,6 @@ py_test(
args = ["--dir=tuned_examples/sac"]
)

# TODO (simon): These tests are not learning, yet.
# py_test(
# name = "learning_tests_multi_agent_pendulum_sac",
# main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous"],
# size = "large",
# srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
# args = ["--enable-new-api-stack", "--num-agents=2"]
# )

# py_test(
# name = "learning_tests_multi_agent_pendulum_sac_multi_gpu",
# main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous", "multi_gpu"],
# size = "large",
# srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
# args = ["--enable-new-api-stack", "--num-agents=2", "--num-gpus=2"]
# )

# --------------------------------------------------------------------
# Algorithms (Compilation, Losses, simple functionality tests)
# rllib/algorithms/
@@ -1549,10 +1557,10 @@ py_test(
)

py_test(
name = "test_multi_agent_episode_replay_buffer",
name = "test_multi_agent_episode_buffer",
tags = ["team:rllib", "utils"],
size = "small",
srcs = ["utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py"]
srcs = ["utils/replay_buffers/tests/test_multi_agent_episode_buffer.py"]
)

py_test(
@@ -1562,6 +1570,13 @@ py_test(
srcs = ["utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py"]
)

py_test(
name = "test_multi_agent_prio_episode_buffer",
tags = ["team:rllib", "utils"],
size = "small",
srcs = ["utils/replay_buffers/tests/test_multi_agent_prio_episode_buffer.py"]
)

py_test(
name = "test_multi_agent_prioritized_replay_buffer",
tags = ["team:rllib", "utils"],
@@ -1577,10 +1592,10 @@ py_test(
)

py_test(
name = "test_prioritized_episode_replay_buffer",
name = "test_prioritized_episode_buffer",
tags = ["team::rllib", "utils"],
size = "small",
srcs = ["utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py"]
srcs = ["utils/replay_buffers/tests/test_prioritized_episode_buffer.py"]
)

py_test(
1 change: 1 addition & 0 deletions rllib/algorithms/sac/sac.py
@@ -355,6 +355,7 @@ def validate(self) -> None:
"EpisodeReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
"MultiAgentEpisodeReplayBuffer",
"MultiAgentPrioritizedEpisodeReplayBuffer",
]:
raise ValueError(
"When using the new `EnvRunner API` the replay buffer must be of type "
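
Note: on the new API stack this check now accepts the prioritized multi-agent episode buffer as well. A minimal configuration sketch, mirroring the tuned examples in this commit (the environment name is a placeholder, the numeric values are illustrative, and the new API stack itself is enabled elsewhere, e.g. via the --enable-new-api-stack flag in the tuned examples):

from ray.rllib.algorithms.sac import SACConfig

# Sketch only: "type", "capacity", "alpha" and "beta" mirror the tuned
# examples in this commit; all values are illustrative, not recommendations.
config = (
    SACConfig()
    .environment("Pendulum-v1")  # placeholder environment
    .training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedEpisodeReplayBuffer",
            "capacity": 100_000,
            "alpha": 1.0,  # 0.0 = uniform sampling, 1.0 = full prioritization
            "beta": 0.0,   # strength of the importance-sampling correction
        },
    )
)
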
17 changes: 11 additions & 6 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -299,10 +299,6 @@ def compute_loss_for_module(
def compute_gradients(
self, loss_per_module: Dict[str, TensorType], **kwargs
) -> ParamDict:
# Set all grads to `None`.
for optim in self._optimizer_parameters:
optim.zero_grad(set_to_none=True)

grads = {}

for module_id in set(loss_per_module.keys()) - {ALL_MODULES}:
@@ -314,14 +310,23 @@ def compute_gradients(
for component in (
["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
):
# Get the optimizer for the current component and module.
optim = self.get_optimizer(module_id, component)
# Zero the gradients. Note, we need to reset the gradients b/c
# each component for a module operates on the same graph.
optim.zero_grad(set_to_none=True)
# Compute the gradients for the component and module.
self.metrics.peek((module_id, component + "_loss")).backward(
retain_graph=True
)
# Store the gradients for the component and module.
# TODO (simon): Check another time the graph for overlapping
# gradients.
grads.update(
{
pid: p.grad
pid: p.grad.clone()
for pid, p in self.filter_param_dict_for_optimizer(
self._params, self.get_optimizer(module_id, component)
self._params, optim
).items()
}
)
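
Note: compute_gradients now zeroes gradients per component instead of once up front, and clones the gradients before storing them. A minimal PyTorch sketch of that pattern outside RLlib (component names and losses are made up for illustration):

import torch

# Two loss "components" sharing the same parameters and forward graph,
# loosely mimicking the qf/policy losses handled above.
params = [torch.nn.Parameter(torch.randn(4))]
optims = {
    "qf": torch.optim.Adam(params, lr=3e-4),
    "policy": torch.optim.Adam(params, lr=3e-4),
}
hidden = params[0] * 2.0
losses = {"qf": hidden.sum(), "policy": (hidden ** 2).sum()}

grads = {}
for component, optim in optims.items():
    # Reset the shared .grad fields; otherwise the previous component's
    # gradients would accumulate into this component's backward pass.
    optim.zero_grad(set_to_none=True)
    # All components backprop through the same graph, hence retain_graph=True.
    losses[component].backward(retain_graph=True)
    # Defensive copy: later zero_grad()/backward() calls can then no longer
    # mutate the tensors that were already collected.
    grads[component] = [p.grad.clone() for p in params]
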
13 changes: 11 additions & 2 deletions rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py
@@ -14,7 +14,7 @@
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

parser.set_defaults(num_agents=2)
register_env(
"multi_agent_cartpole",
lambda _: MultiAgentCartPole({"num_agents": args.num_agents}),
@@ -27,8 +27,10 @@
# Settings identical to old stack.
train_batch_size_per_learner=32,
replay_buffer_config={
"type": "MultiAgentEpisodeReplayBuffer",
"type": "MultiAgentPrioritizedEpisodeReplayBuffer",
"capacity": 50000,
"alpha": 0.6,
"beta": 0.4,
},
n_step=3,
double_q=True,
@@ -61,6 +63,13 @@
}

if __name__ == "__main__":
assert (
args.num_agents > 0
), "The `--num-agents` arg must be > 0 for this script to work."
assert (
args.enable_new_api_stack
), "The `--enable-new-api-stack` arg must be activated for this script to work."

from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

run_rllib_example_script_experiment(config, args, stop=stop)
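
Note: the new "alpha" and "beta" keys follow the standard prioritized-experience-replay scheme (Schaul et al., 2016): alpha shapes the sampling distribution over priorities, beta the importance-sampling correction of the resulting bias. A small, self-contained sketch of those formulas (not RLlib's internal implementation):

import numpy as np

def per_probs_and_weights(priorities, alpha=0.6, beta=0.4):
    """Standard PER: P(i) is proportional to p_i**alpha, w_i = (N * P(i))**(-beta)."""
    p = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = p / p.sum()
    weights = (len(p) * probs) ** (-beta)
    # Normalize by the max weight, as in the original paper.
    return probs, weights / weights.max()

# alpha=0.0 -> uniform sampling; beta=1.0 -> full bias correction.
probs, weights = per_probs_and_weights([2.0, 1.0, 0.5, 0.1])
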
24 changes: 21 additions & 3 deletions rllib/tuned_examples/sac/multi_agent_pendulum_sac.py
@@ -25,7 +25,16 @@
config = (
SACConfig()
.environment(env="multi_agent_pendulum")
.env_runners(num_env_runners=2)
.rl_module(
model_config_dict={
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"post_fcnet_weights_initializer": "orthogonal_",
"post_fcnet_weights_initializer_config": {"gain": 0.01},
}
)
.training(
initial_alpha=1.001,
lr=3e-4,
@@ -35,10 +44,12 @@
train_batch_size_per_learner=256,
target_network_update_freq=1,
replay_buffer_config={
"type": "MultiAgentEpisodeReplayBuffer",
"type": "MultiAgentPrioritizedEpisodeReplayBuffer",
"capacity": 100000,
"alpha": 1.0,
"beta": 0.0,
},
num_steps_sampled_before_learning_starts=1024,
num_steps_sampled_before_learning_starts=256,
)
.rl_module(
model_config_dict={
@@ -70,6 +81,13 @@
}

if __name__ == "__main__":
assert (
args.num_agents > 0
), "The `--num-agents` arg must be > 0 for this script to work."
assert (
args.enable_new_api_stack
), "The `--enable-new-api-stack` arg must be activated for this script to work."

from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

run_rllib_example_script_experiment(config, args, stop=stop)
2 changes: 1 addition & 1 deletion rllib/tuned_examples/sac/pendulum_sac.py
@@ -6,7 +6,7 @@
default_reward=-250.0,
)
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values toset up `config` below.
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

config = (
8 changes: 6 additions & 2 deletions rllib/utils/replay_buffers/__init__.py
@@ -3,17 +3,20 @@
from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
MultiAgentMixInReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_episode_replay_buffer import (
from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
MultiAgentEpisodeReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_episode_buffer import (
MultiAgentPrioritizedEpisodeReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
MultiAgentPrioritizedReplayBuffer,
)
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayMode,
)
from ray.rllib.utils.replay_buffers.prioritized_episode_replay_buffer import (
from ray.rllib.utils.replay_buffers.prioritized_episode_buffer import (
PrioritizedEpisodeReplayBuffer,
)
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
@@ -28,6 +31,7 @@
"FifoReplayBuffer",
"MultiAgentEpisodeReplayBuffer",
"MultiAgentMixInReplayBuffer",
"MultiAgentPrioritizedEpisodeReplayBuffer",
"MultiAgentPrioritizedReplayBuffer",
"MultiAgentReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
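
Note: with the module renames and the new export above, the prioritized multi-agent buffer can be imported from the package's top level. A short sketch (the constructor arguments are assumed from the tuned examples in this commit; check the class signature for the authoritative set of options):

from ray.rllib.utils.replay_buffers import MultiAgentPrioritizedEpisodeReplayBuffer

# Arguments assumed from the tuned examples above (capacity/alpha/beta).
buffer = MultiAgentPrioritizedEpisodeReplayBuffer(
    capacity=50_000,
    alpha=0.6,
    beta=0.4,
)
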
44 changes: 41 additions & 3 deletions rllib/utils/replay_buffers/episode_replay_buffer.py
@@ -336,6 +336,10 @@ def get_num_episodes(self) -> int:
"""Returns number of episodes (completed or truncated) stored in the buffer."""
return len(self.episodes)

def get_num_episodes_evicted(self) -> int:
"""Returns number of episodes that have been evicted from the buffer."""
return self._num_episodes_evicted

def get_num_timesteps(self) -> int:
"""Returns number of individual timesteps stored in the buffer."""
return len(self._indices)
@@ -350,6 +354,15 @@ def get_added_timesteps(self) -> int:

@override(ReplayBufferInterface)
def get_state(self) -> Dict[str, Any]:
"""Gets a pickable state of the buffer.
This is used for checkpointing the buffer's state. It is specifically helpful,
for example, when a trial is paused and resumed later on. The buffer's state
can be saved to disk and reloaded when the trial is resumed.
Returns:
A dict containing all necessary information to restore the buffer's state.
"""
return {
"episodes": [eps.get_state() for eps in self.episodes],
"episode_id_to_index": list(self.episode_id_to_index.items()),
@@ -362,12 +375,37 @@ def get_state(self) -> Dict[str, Any]:

@override(ReplayBufferInterface)
def set_state(self, state) -> None:
self.episodes = deque(
[SingleAgentEpisode.from_state(eps_data) for eps_data in state["episodes"]]
)
"""Sets the state of a buffer from a previously stored state.
See `get_state()` for more information on what is stored in the state. This
method is used to restore the buffer's state from a previously stored state.
It is specifically helpful, for example, when a trial is paused and resumed
later on. The buffer's state can be saved to disk and reloaded when the trial
is resumed.
Args:
state: The state to restore the buffer from.
"""
self._set_episodes(state)
self.episode_id_to_index = dict(state["episode_id_to_index"])
self._num_episodes_evicted = state["_num_episodes_evicted"]
self._indices = state["_indices"]
self._num_timesteps = state["_num_timesteps"]
self._num_timesteps_added = state["_num_timesteps_added"]
self.sampled_timesteps = state["sampled_timesteps"]

def _set_episodes(self, state) -> None:
"""Sets the episodes from the state.
Note, this method is used for class inheritance purposes. It is specifically
helpful when a subclass of this class wants to override the behavior of how
episodes are set from the state. By default, it sets `SingleAgentEpisode`s,
but subclasses can override this method to set episodes of a different type.
"""
if not self.episodes:
self.episodes = deque(
[
SingleAgentEpisode.from_state(eps_data)
for eps_data in state["episodes"]
]
)