[RLlib; Off-policy] Add episode sampling to EpisodeReplayBuffer. (ray-project#47500)

Signed-off-by: ujjawal-khare <[email protected]>
simonsays1980 authored and ujjawal-khare committed Oct 15, 2024
1 parent f3e8e1f commit 5ecfd27
Showing 4 changed files with 307 additions and 1 deletion.
1 change: 1 addition & 0 deletions rllib/algorithms/dqn/dqn.py
@@ -719,6 +719,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
n_step=self.config.n_step,
gamma=self.config.gamma,
beta=self.config.replay_buffer_config.get("beta"),
sample_episodes=True,
)
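# With `sample_episodes=True`, `EpisodeReplayBuffer.sample()` (see the
# changes below) returns a list of single-transition `SingleAgentEpisode`
# objects instead of a flat sample batch.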

# Perform an update on the buffer-sampled train batch.
257 changes: 256 additions & 1 deletion rllib/utils/replay_buffers/episode_replay_buffer.py
@@ -1,8 +1,9 @@
from collections import deque
import copy
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.annotations import override
@@ -211,6 +212,87 @@ def sample(
*,
batch_size_B: Optional[int] = None,
batch_length_T: Optional[int] = None,
n_step: Optional[Union[int, Tuple]] = None,
beta: float = 0.0,
gamma: float = 0.99,
include_infos: bool = False,
include_extra_model_outputs: bool = False,
sample_episodes: Optional[bool] = False,
**kwargs,
) -> Union[SampleBatchType, SingleAgentEpisode]:
"""Samples from a buffer in a randomized way.
Each sampled item defines a transition of the form:
`(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
where `o_t` is drawn by randomized sampling.`n` is defined by the `n_step`
applied.
If requested, `info`s of a transitions last timestep `t+n` and respective
extra model outputs (e.g. action log-probabilities) are added to
the batch.
Args:
num_items: Number of items (transitions) to sample from this
buffer.
batch_size_B: The number of rows (transitions) to return in the
batch.
batch_length_T: The sequence length to sample. At this point in time,
only sequences of length 1 are possible.
n_step: The n-step to apply. By default, the batch contains in
`"new_obs"` the observation and in `"obs"` the observation `n`
time steps earlier. The reward is the sum of the rewards
collected in between these two observations, and the action is
the one executed `n` steps before, such that we always have the
state-action pair that triggered the rewards.
If `n_step` is a tuple, it is treated as a range to sample
from. If `None`, `n_step=1` is used.
beta: Only relevant for prioritized sampling; this (uniform) buffer
accepts it for API compatibility and ignores it.
gamma: The discount factor to be used when applying n-step calculations.
The default of `0.99` should be replaced by the `Algorithm`'s
discount factor.
include_infos: Whether `info`s should be included in the batch. This can
be useful if the `info` contains values from the environment that are
needed for loss computation. If `True`, the info at the `"new_obs"`
timestep is included in the batch.
include_extra_model_outputs: Whether `extra_model_outputs` should be
included in the batch. This can be useful if the `extra_model_outputs`
contain outputs from the model that are needed for loss computation and
can only be computed with the actual state of the model (e.g. action
log-probabilities). If `True`, the extra model outputs at the `"obs"`
timestep (the timestep at which the action is computed) are included
in the batch.
sample_episodes: If True, sample and return a list of 1-step
`SingleAgentEpisode` instances instead of a batch.
Returns:
Either a batch with transitions in each row or (if `sample_episodes=True`)
a list of 1-step long episodes containing all basic episode data and, if
requested, infos and extra model outputs.
"""

if sample_episodes:
return self._sample_episodes(
num_items=num_items,
batch_size_B=batch_size_B,
batch_length_T=batch_length_T,
n_step=n_step,
beta=beta,
gamma=gamma,
include_infos=include_infos,
include_extra_model_outputs=include_extra_model_outputs,
)
else:
return self._sample_batch(
num_items=num_items,
batch_size_B=batch_size_B,
batch_length_T=batch_length_T,
)

def _sample_batch(
self,
num_items: Optional[int] = None,
*,
batch_size_B: Optional[int] = None,
batch_length_T: Optional[int] = None,
) -> SampleBatchType:
"""Returns a batch of size B (number of "rows"), where each row has length T.
@@ -332,6 +414,179 @@ def sample(

return ret

def _sample_episodes(
self,
num_items: Optional[int] = None,
*,
batch_size_B: Optional[int] = None,
batch_length_T: Optional[int] = None,
n_step: Optional[Union[int, Tuple]] = None,
gamma: float = 0.99,
include_infos: bool = False,
include_extra_model_outputs: bool = False,
**kwargs,
) -> List[SingleAgentEpisode]:
"""Samples episodes from a buffer in a randomized way.
Each sampled item defines a transition of the form:
`(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
where `o_t` is drawn by randomized sampling.`n` is defined by the `n_step`
applied.
If requested, `info`s of a transitions last timestep `t+n` and respective
extra model outputs (e.g. action log-probabilities) are added to
the batch.
Args:
num_items: Number of items (transitions) to sample from this
buffer.
batch_size_B: The number of rows (transitions) to return in the
batch.
batch_length_T: The sequence length to sample. At this point in time,
only sequences of length 1 are possible.
n_step: The n-step to apply. By default, the batch contains in
`"new_obs"` the observation and in `"obs"` the observation `n`
time steps earlier. The reward is the sum of the rewards
collected in between these two observations, and the action is
the one executed `n` steps before, such that we always have the
state-action pair that triggered the rewards.
If `n_step` is a tuple, it is treated as a range to sample
from. If `None`, `n_step=1` is used.
gamma: The discount factor to be used when applying n-step calculations.
The default of `0.99` should be replaced by the `Algorithm`'s
discount factor.
include_infos: Whether `info`s should be included in the batch. This can
be useful if the `info` contains values from the environment that are
needed for loss computation. If `True`, the info at the `"new_obs"`
timestep is included in the batch.
include_extra_model_outputs: Whether `extra_model_outputs` should be
included in the batch. This can be useful if the `extra_model_outputs`
contain outputs from the model that are needed for loss computation and
can only be computed with the actual state of the model (e.g. action
log-probabilities). If `True`, the extra model outputs at the `"obs"`
timestep (the timestep at which the action is computed) are included
in the batch.
Returns:
A list of 1-step long episodes containing all basic episode data and, if
requested, infos and extra model outputs.
"""
if num_items is not None:
assert batch_size_B is None, (
"Cannot call `sample()` with both `num_items` and `batch_size_B` "
"provided! Use either one."
)
batch_size_B = num_items

# Use our default values if no sizes/lengths provided.
batch_size_B = batch_size_B or self.batch_size_B
# TODO (simon): Implement trajectory sampling for RNNs.
batch_length_T = batch_length_T or self.batch_length_T

# Sample the n-step if necessary.
actual_n_step = n_step or 1
random_n_step = isinstance(n_step, tuple)
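# Note: a tuple `n_step=(low, high)` is treated as a half-open range; a
# per-transition n-step is drawn uniformly from `[low, high)` further below.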

# Keep track of the indices that were sampled last for updating the
# weights later (see `ray.rllib.utils.replay_buffer.utils.
# update_priorities_in_episode_replay_buffer`).
self._last_sampled_indices = []

sampled_episodes = []

B = 0
while B < batch_size_B:
# Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx).
index_tuple = self._indices[self.rng.integers(len(self._indices))]

# Compute the actual episode index (offset by the number of
# already evicted episodes).
episode_idx, episode_ts = (
index_tuple[0] - self._num_episodes_evicted,
index_tuple[1],
)
episode = self.episodes[episode_idx]

# If we use random n-step sampling, draw the n-step for this item.
if random_n_step:
actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))

# Skip, if we are too far to the end and `episode_ts` + n_step would go
# beyond the episode's end.
if episode_ts + actual_n_step > len(episode):
continue

# Note, this will be the reward after executing action
# `a_(episode_ts-n_step+1)`. For `n_step>1`, this will be the discounted
# sum of all rewards that were collected over the last n steps.
raw_rewards = episode.get_rewards(
slice(episode_ts, episode_ts + actual_n_step)
)
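# `scipy.signal.lfilter` with coefficients b=[1], a=[1, -gamma] computes a
# discounted cumulative sum; running it over the reversed rewards and taking
# the last element yields
# raw_rewards[0] + gamma * raw_rewards[1] + ... + gamma^(n-1) * raw_rewards[n-1].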
rewards = scipy.signal.lfilter([1], [1, -gamma], raw_rewards[::-1], axis=0)[
-1
]

# Generate the episode to be returned.
sampled_episode = SingleAgentEpisode(
# Ensure that each episode contains a tuple of the form:
# (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
# Two observations (t and t+n).
observations=[
episode.get_observations(episode_ts),
episode.get_observations(episode_ts + actual_n_step),
],
observation_space=episode.observation_space,
infos=(
[
episode.get_infos(episode_ts),
episode.get_infos(episode_ts + actual_n_step),
]
if include_infos
else None
),
actions=[episode.get_actions(episode_ts)],
action_space=episode.action_space,
rewards=[rewards],
# If the sampled time step is the episode's last time step, check
# whether the episode is terminated or truncated.
terminated=(
False
if episode_ts + actual_n_step < len(episode)
else episode.is_terminated
),
truncated=(
False
if episode_ts + actual_n_step < len(episode)
else episode.is_truncated
),
extra_model_outputs={
# TODO (simon): Check if we have to correct here for sequences
# later.
"n_step": [actual_n_step],
**(
{
k: [episode.get_extra_model_outputs(k, episode_ts)]
for k in episode.extra_model_outputs.keys()
}
if include_extra_model_outputs
else {}
),
},
# TODO (sven): Support lookback buffers.
len_lookback_buffer=0,
t_started=episode_ts,
)
sampled_episodes.append(sampled_episode)

# Increment counter.
B += 1

self.sampled_timesteps += batch_size_B

return sampled_episodes

def get_num_episodes(self) -> int:
"""Returns number of episodes (completed or truncated) stored in the buffer."""
return len(self.episodes)
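For illustration, here is a minimal usage sketch of the new episode-sampling path, mirroring the new unit test further down (buffer capacity, `add()`, and the `sample()` call are taken from this diff; `collected_episodes` is a hypothetical iterable of episodes, and the loop body only shows how the returned episodes can be read back):

from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer

buffer = EpisodeReplayBuffer(capacity=10000)
# Fill the buffer with complete `SingleAgentEpisode` objects.
for episode in collected_episodes:  # `collected_episodes` is assumed to exist.
    buffer.add(episode)

# Sample 16 one-step episodes (instead of a flat batch).
episodes = buffer.sample(
    batch_size_B=16,
    n_step=1,
    gamma=0.99,
    sample_episodes=True,
)
for eps in episodes:
    obs = eps.get_observations(0)        # o_t
    next_obs = eps.get_observations(-1)  # o_(t+n)
    action = eps.get_actions(-1)         # a_t
    reward = eps.get_rewards(-1)         # discounted n-step return
    n_step = eps.get_extra_model_outputs("n_step", -1)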
1 change: 1 addition & 0 deletions rllib/utils/replay_buffers/prioritized_episode_buffer.py
@@ -310,6 +310,7 @@ def sample(
gamma: float = 0.99,
include_infos: bool = False,
include_extra_model_outputs: bool = False,
**kwargs,
) -> SampleBatchType:
"""Samples from a buffer in a prioritized way.
49 changes: 49 additions & 0 deletions rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py
@@ -6,6 +6,8 @@
EpisodeReplayBuffer,
)

from ray.rllib.utils.test_utils import check


class TestEpisodeReplayBuffer(unittest.TestCase):
@staticmethod
@@ -139,6 +141,53 @@ def test_episode_replay_buffer_sample_logic(self):
# (reset rewards).
assert np.all(np.where(is_terminated[:, :-1], rewards[:, 1:] == 0.0, True))

def test_episode_replay_buffer_episode_sample_logic(self):

buffer = EpisodeReplayBuffer(capacity=10000)

for _ in range(200):
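# Note: `self._get_episode()` (defined earlier in this test file) is assumed
# to build episodes whose observations increase by 1 per step, with
# action_t == obs_t and reward_t == obs_(t+1) / 10; the assertions further
# below rely on exactly these conventions.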
episode = self._get_episode()
buffer.add(episode)

for i in range(1000):
sample = buffer.sample(batch_size_B=16, n_step=1, sample_episodes=True)
check(buffer.get_sampled_timesteps(), 16 * (i + 1))
for eps in sample:

(
obs,
action,
reward,
next_obs,
is_terminated,
is_truncated,
n_step,
) = (
eps.get_observations(0),
eps.get_actions(-1),
eps.get_rewards(-1),
eps.get_observations(-1),
eps.is_terminated,
eps.is_truncated,
eps.get_extra_model_outputs("n_step", -1),
)

# Make sure terminated and truncated are never both True.
assert not (is_truncated and is_terminated)

# Note, floating point numbers cannot be compared directly.
tolerance = 1e-8
# Assert that actions correspond to the observations.
check(obs, action, atol=tolerance)
# Assert that next observations are correctly one step after
# observations.
check(next_obs, obs + 1, atol=tolerance)
# Assert that the reward comes from the next observation.
check(reward * 10, next_obs, atol=tolerance)

# Assert that all n-steps are 1.0 as passed into `sample`.
check(n_step, 1.0, atol=tolerance)


if __name__ == "__main__":
import pytest
