Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Prioritized version of an episode-based replay buffer. #42832

Merged
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
bcdabd0
Initialized prioritized episode replay buffer.
simonsays1980 Jan 25, 2024
1409857
Added 'add' method for replay buffer.
simonsays1980 Jan 26, 2024
8eedc7b
Added sampling method for 'PrioritizedEpisodeReplayBuffer'. Needs to …
simonsays1980 Jan 29, 2024
3a58ae3
LINTER.
simonsays1980 Jan 29, 2024
79fd9e6
Modified 'add()' and 'sample()' methods in 'PrioritiztedEpisodeRepla…
simonsays1980 Jan 30, 2024
ee89408
Added docstrings and a couple of todos.
simonsays1980 Jan 30, 2024
f364e1e
Added tests for 'update_priorities'.
simonsays1980 Jan 30, 2024
6f24a56
Added optional sampling of 'n_step' in 'sample()'-method.
simonsays1980 Jan 31, 2024
b4939a5
Merge branch 'master' into prioritized-episode-replay-buffer
simonsays1980 Jan 31, 2024
ae97345
Added a util function to update priorities if multi-agent setting is …
simonsays1980 Jan 31, 2024
41025df
Corrected some typos and added more comments.
simonsays1980 Jan 31, 2024
d8fbe57
Added changes from @sven1977's review.
simonsays1980 Jan 31, 2024
2bf19d3
Update rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
simonsays1980 Jan 31, 2024
81366ad
Merge branch 'master' into prioritized-episode-replay-buffer
sven1977 Jan 31, 2024
d337fc0
Made lines shorted and added tests to 'rllib/BUILD'.
simonsays1980 Feb 1, 2024
d77e9f8
Merge branch 'master' into prioritized-episode-replay-buffer
simonsays1980 Feb 1, 2024
cd3cab0
CHanged line length after merge.
simonsays1980 Feb 1, 2024
e1a1659
Introduced discounting in 'sample()' for n-step.
simonsays1980 Feb 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3409,7 +3409,8 @@ def get_marl_module_spec(
"is passed in nor in the default module spec used in "
"the algorithm."
)

# TODO (sven): Find a good way to pack module specific parameters from
# the algorithms into the `model_config_dict`.
if module_spec.observation_space is None:
module_spec.observation_space = policy_spec.observation_space
if module_spec.action_space is None:
Expand Down
4 changes: 4 additions & 0 deletions rllib/utils/replay_buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
MultiAgentReplayBuffer,
ReplayMode,
)
from ray.rllib.utils.replay_buffers.prioritized_episode_replay_buffer import (
PrioritizedEpisodeReplayBuffer,
)
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
)
Expand All @@ -23,6 +26,7 @@
"MultiAgentMixInReplayBuffer",
"MultiAgentPrioritizedReplayBuffer",
"MultiAgentReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
"PrioritizedReplayBuffer",
"ReplayMode",
"ReplayBuffer",
Expand Down
6 changes: 4 additions & 2 deletions rllib/utils/replay_buffers/episode_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,16 @@ def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"])
eps_idx = self.episode_id_to_index[eps.id_]
existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
old_len = len(existing_eps)
self._indices.extend([(eps_idx, old_len + i) for i in range(len(eps))])
self._indices.extend(
simonsays1980 marked this conversation as resolved.
Show resolved Hide resolved
[(eps_idx, old_len + i, None) for i in range(len(eps))]
)
existing_eps.concat_episode(eps)
# New episode. Add to end of our episodes deque.
else:
self.episodes.append(eps)
eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
self.episode_id_to_index[eps.id_] = eps_idx
self._indices.extend([(eps_idx, i) for i in range(len(eps))])
self._indices.extend([(eps_idx, i, None) for i in range(len(eps))])

# Eject old records from front of deque (only if we have more than 1 episode
# in the buffer).
Expand Down
Loading
Loading