diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py
index 013508ab7c2e..6235ccc3d335 100644
--- a/rllib/algorithms/dqn/dqn.py
+++ b/rllib/algorithms/dqn/dqn.py
@@ -719,6 +719,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
                 n_step=self.config.n_step,
                 gamma=self.config.gamma,
                 beta=self.config.replay_buffer_config.get("beta"),
+                sample_episodes=True,
             )

             # Perform an update on the buffer-sampled train batch.
diff --git a/rllib/utils/replay_buffers/episode_replay_buffer.py b/rllib/utils/replay_buffers/episode_replay_buffer.py
index c61b76596e4f..90ae1fa35664 100644
--- a/rllib/utils/replay_buffers/episode_replay_buffer.py
+++ b/rllib/utils/replay_buffers/episode_replay_buffer.py
@@ -1,8 +1,9 @@
 from collections import deque
 import copy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
+import scipy

 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.utils.annotations import override
@@ -211,6 +212,87 @@ def sample(
         *,
         batch_size_B: Optional[int] = None,
         batch_length_T: Optional[int] = None,
+        n_step: Optional[Union[int, Tuple]] = None,
+        beta: float = 0.0,
+        gamma: float = 0.99,
+        include_infos: bool = False,
+        include_extra_model_outputs: bool = False,
+        sample_episodes: Optional[bool] = False,
+        **kwargs,
+    ) -> Union[SampleBatchType, SingleAgentEpisode]:
+        """Samples from a buffer in a randomized way.
+
+        Each sampled item defines a transition of the form:
+
+        `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
+
+        where `o_t` is drawn by randomized sampling. `n` is defined by the
+        `n_step` applied.
+
+        If requested, the `info`s of a transition's last timestep `t+n` and the
+        respective extra model outputs (e.g. action log-probabilities) are added
+        to the batch.
+
+        Args:
+            num_items: Number of items (transitions) to sample from this
+                buffer.
+            batch_size_B: The number of rows (transitions) to return in the
+                batch.
+            batch_length_T: The sequence length to sample. At this point in
+                time, only sequences of length 1 are possible.
+            n_step: The n-step to apply. By default, the batch contains in
+                `"new_obs"` the observation and in `"obs"` the observation `n`
+                time steps earlier. The reward will be the discounted sum of
+                rewards collected in between these two observations and the
+                action will be the one executed `n` steps before, such that we
+                always have the state-action pair that triggered the rewards.
+                If `n_step` is a tuple, it is considered a range to sample `n`
+                from. If `None`, `n_step=1` is used.
+            gamma: The discount factor to be used when applying n-step
+                calculations. The default of `0.99` should be replaced by the
+                `Algorithm`'s discount factor.
+            include_infos: A boolean indicating whether `info`s should be
+                included in the batch. This can be useful if the `info` contains
+                values from the environment that are important for loss
+                computation. If `True`, the info at the `"new_obs"` in the batch
+                is included.
+            include_extra_model_outputs: A boolean indicating whether
+                `extra_model_outputs` should be included in the batch. This can
+                be useful if the `extra_model_outputs` contain outputs from the
+                model that are important for loss computation and can only be
+                computed with the actual state of the model (e.g. action
+                log-probabilities). If `True`, the extra model outputs at the
+                `"obs"` in the batch are included (the timestep at which the
+                action is computed).
+            sample_episodes: If `True`, return a list of 1-step long
+                `SingleAgentEpisode`s (one per sampled transition) instead of a
+                batch.
+
+        Returns:
+            Either a batch with transitions in each row or (if
+            `sample_episodes=True`) a list of 1-step long episodes containing
+            all basic episode data and, if requested, infos and extra model
+            outputs.
+        """
+
+        if sample_episodes:
+            return self._sample_episodes(
+                num_items=num_items,
+                batch_size_B=batch_size_B,
+                batch_length_T=batch_length_T,
+                n_step=n_step,
+                beta=beta,
+                gamma=gamma,
+                include_infos=include_infos,
+                include_extra_model_outputs=include_extra_model_outputs,
+            )
+        else:
+            return self._sample_batch(
+                num_items=num_items,
+                batch_size_B=batch_size_B,
+                batch_length_T=batch_length_T,
+            )
+
+    def _sample_batch(
+        self,
+        num_items: Optional[int] = None,
+        *,
+        batch_size_B: Optional[int] = None,
+        batch_length_T: Optional[int] = None,
     ) -> SampleBatchType:
         """Returns a batch of size B (number of "rows"), where each row has length T.

@@ -332,6 +414,179 @@ def sample(

         return ret

+    def _sample_episodes(
+        self,
+        num_items: Optional[int] = None,
+        *,
+        batch_size_B: Optional[int] = None,
+        batch_length_T: Optional[int] = None,
+        n_step: Optional[Union[int, Tuple]] = None,
+        gamma: float = 0.99,
+        include_infos: bool = False,
+        include_extra_model_outputs: bool = False,
+        **kwargs,
+    ) -> List[SingleAgentEpisode]:
+        """Samples episodes from a buffer in a randomized way.
+
+        Each sampled item defines a transition of the form:
+
+        `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
+
+        where `o_t` is drawn by randomized sampling. `n` is defined by the
+        `n_step` applied.
+
+        If requested, the `info`s of a transition's last timestep `t+n` and the
+        respective extra model outputs (e.g. action log-probabilities) are added
+        to the batch.
+
+        Args:
+            num_items: Number of items (transitions) to sample from this
+                buffer.
+            batch_size_B: The number of rows (transitions) to return in the
+                batch.
+            batch_length_T: The sequence length to sample. At this point in
+                time, only sequences of length 1 are possible.
+            n_step: The n-step to apply. By default, the batch contains in
+                `"new_obs"` the observation and in `"obs"` the observation `n`
+                time steps earlier. The reward will be the discounted sum of
+                rewards collected in between these two observations and the
+                action will be the one executed `n` steps before, such that we
+                always have the state-action pair that triggered the rewards.
+                If `n_step` is a tuple, it is considered a range to sample `n`
+                from. If `None`, `n_step=1` is used.
+            gamma: The discount factor to be used when applying n-step
+                calculations. The default of `0.99` should be replaced by the
+                `Algorithm`'s discount factor.
+            include_infos: A boolean indicating whether `info`s should be
+                included in the batch. This can be useful if the `info` contains
+                values from the environment that are important for loss
+                computation. If `True`, the info at the `"new_obs"` in the batch
+                is included.
+            include_extra_model_outputs: A boolean indicating whether
+                `extra_model_outputs` should be included in the batch. This can
+                be useful if the `extra_model_outputs` contain outputs from the
+                model that are important for loss computation and can only be
+                computed with the actual state of the model (e.g. action
+                log-probabilities). If `True`, the extra model outputs at the
+                `"obs"` in the batch are included (the timestep at which the
+                action is computed).
+
+        Returns:
+            A list of 1-step long episodes containing all basic episode data
+            and, if requested, infos and extra model outputs.
+        """
+        if num_items is not None:
+            assert batch_size_B is None, (
+                "Cannot call `sample()` with both `num_items` and `batch_size_B` "
+                "provided! Use either one."
+            )
+            batch_size_B = num_items
+
+        # Use our default values if no sizes/lengths provided.
+        batch_size_B = batch_size_B or self.batch_size_B
+        # TODO (simon): Implement trajectory sampling for RNNs.
+        batch_length_T = batch_length_T or self.batch_length_T
+
+        # Sample the n-step if necessary.
+        actual_n_step = n_step or 1
+        random_n_step = isinstance(n_step, tuple)
+
+        # Keep track of the indices that were sampled last for updating the
+        # weights later (see `ray.rllib.utils.replay_buffer.utils.
+        # update_priorities_in_episode_replay_buffer`).
+        self._last_sampled_indices = []
+
+        sampled_episodes = []
+
+        B = 0
+        while B < batch_size_B:
+            # Pull a new uniform random index tuple: (eps_idx, ts_in_eps_idx).
+            index_tuple = self._indices[self.rng.integers(len(self._indices))]
+
+            # Compute the actual episode index (offset by the number of
+            # already evicted episodes).
+            episode_idx, episode_ts = (
+                index_tuple[0] - self._num_episodes_evicted,
+                index_tuple[1],
+            )
+            episode = self.episodes[episode_idx]
+
+            # If we use random n-step sampling, draw the n-step for this item.
+            if random_n_step:
+                actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
+
+            # Skip if we are too close to the end and `episode_ts` + n_step
+            # would go beyond the episode's end.
+            if episode_ts + actual_n_step > len(episode):
+                continue
+
+            # Note, this will be the reward for executing action
+            # `a_(episode_ts)`. For `n_step>1` this will be the discounted sum
+            # of all rewards collected over the last n steps.
+            raw_rewards = episode.get_rewards(
+                slice(episode_ts, episode_ts + actual_n_step)
+            )
+            rewards = scipy.signal.lfilter([1], [1, -gamma], raw_rewards[::-1], axis=0)[
+                -1
+            ]
+
+            # Generate the episode to be returned.
+            sampled_episode = SingleAgentEpisode(
+                # Ensure that each episode contains a tuple of the form:
+                #   (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
+                # Two observations (t and t+n).
+                observations=[
+                    episode.get_observations(episode_ts),
+                    episode.get_observations(episode_ts + actual_n_step),
+                ],
+                observation_space=episode.observation_space,
+                infos=(
+                    [
+                        episode.get_infos(episode_ts),
+                        episode.get_infos(episode_ts + actual_n_step),
+                    ]
+                    if include_infos
+                    else None
+                ),
+                actions=[episode.get_actions(episode_ts)],
+                action_space=episode.action_space,
+                rewards=[rewards],
+                # If the sampled time step is the episode's last time step,
+                # check whether the episode is terminated or truncated.
+                terminated=(
+                    False
+                    if episode_ts + actual_n_step < len(episode)
+                    else episode.is_terminated
+                ),
+                truncated=(
+                    False
+                    if episode_ts + actual_n_step < len(episode)
+                    else episode.is_truncated
+                ),
+                extra_model_outputs={
+                    # TODO (simon): Check if we have to correct here for
+                    # sequences later.
+                    "n_step": [actual_n_step],
+                    **(
+                        {
+                            k: [episode.get_extra_model_outputs(k, episode_ts)]
+                            for k in episode.extra_model_outputs.keys()
+                        }
+                        if include_extra_model_outputs
+                        else {}
+                    ),
+                },
+                # TODO (sven): Support lookback buffers.
+                len_lookback_buffer=0,
+                t_started=episode_ts,
+            )
+            sampled_episodes.append(sampled_episode)
+
+            # Increment counter.
+            B += 1
+
+        self.sampled_timesteps += batch_size_B
+
+        return sampled_episodes
+
     def get_num_episodes(self) -> int:
         """Returns number of episodes (completed or truncated) stored in the buffer."""
         return len(self.episodes)
diff --git a/rllib/utils/replay_buffers/prioritized_episode_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_buffer.py
index 38bbf3088697..2eee2982e188 100644
--- a/rllib/utils/replay_buffers/prioritized_episode_buffer.py
+++ b/rllib/utils/replay_buffers/prioritized_episode_buffer.py
@@ -310,6 +310,7 @@ def sample(
         gamma: float = 0.99,
         include_infos: bool = False,
         include_extra_model_outputs: bool = False,
+        **kwargs,
     ) -> SampleBatchType:
         """Samples from a buffer in a prioritized way.

diff --git a/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py
index 2f9dd1b20e10..12b4c2ccd309 100644
--- a/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py
+++ b/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py
@@ -6,6 +6,8 @@
     EpisodeReplayBuffer,
 )

+from ray.rllib.utils.test_utils import check
+

 class TestEpisodeReplayBuffer(unittest.TestCase):
     @staticmethod
@@ -139,6 +141,53 @@ def test_episode_replay_buffer_sample_logic(self):
         # (reset rewards).
         assert np.all(np.where(is_terminated[:, :-1], rewards[:, 1:] == 0.0, True))

+    def test_episode_replay_buffer_episode_sample_logic(self):
+        buffer = EpisodeReplayBuffer(capacity=10000)
+
+        for _ in range(200):
+            episode = self._get_episode()
+            buffer.add(episode)
+
+        for i in range(1000):
+            sample = buffer.sample(batch_size_B=16, n_step=1, sample_episodes=True)
+            check(buffer.get_sampled_timesteps(), 16 * (i + 1))
+            for eps in sample:
+                (
+                    obs,
+                    action,
+                    reward,
+                    next_obs,
+                    is_terminated,
+                    is_truncated,
+                    n_step,
+                ) = (
+                    eps.get_observations(0),
+                    eps.get_actions(-1),
+                    eps.get_rewards(-1),
+                    eps.get_observations(-1),
+                    eps.is_terminated,
+                    eps.is_truncated,
+                    eps.get_extra_model_outputs("n_step", -1),
+                )
+
+                # Make sure terminated and truncated are never both True.
+                assert not (is_truncated and is_terminated)
+
+                # Note, floating point numbers cannot be compared directly.
+                tolerance = 1e-8
+                # Assert that actions correspond to the observations.
+                check(obs, action, atol=tolerance)
+                # Assert that next observations are correctly one step after
+                # observations.
+                check(next_obs, obs + 1, atol=tolerance)
+                # Assert that the reward comes from the next observation.
+                check(reward * 10, next_obs, atol=tolerance)
+
+                # Assert that all n-steps are 1.0 as passed into `sample`.
+                check(n_step, 1.0, atol=tolerance)
+


 if __name__ == "__main__":
     import pytest
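
Usage sketch (not part of the diff above): a minimal example of how the new `sample_episodes=True` path could be exercised. Only calls that appear in the hunks are used; the toy episode data (obs_t = t, action_t = t, reward_t = (t + 1) / 10) mirrors the pattern the new test asserts on and is an assumption, as are the capacity and batch sizes.

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer

# Build a toy 100-step episode with the same constructor keywords the buffer
# itself uses (observations hold one more entry than actions/rewards).
# This stands in for episodes collected by an EnvRunner.
episode = SingleAgentEpisode(
    observations=[float(t) for t in range(101)],
    actions=[float(t) for t in range(100)],
    rewards=[(t + 1) / 10 for t in range(100)],
    terminated=True,
    len_lookback_buffer=0,
)

buffer = EpisodeReplayBuffer(capacity=10000)
buffer.add(episode)

# New path: a list of 1-step `SingleAgentEpisode`s, each carrying
# (o_t, a_t, discounted n-step return, o_(t+n)) plus an "n_step" entry
# in `extra_model_outputs`.
sampled = buffer.sample(
    batch_size_B=16,
    n_step=(1, 3),  # draw n uniformly from [1, 3) per sampled transition
    gamma=0.99,
    sample_episodes=True,
)
for eps in sampled:
    obs = eps.get_observations(0)
    action = eps.get_actions(-1)
    n_step_return = eps.get_rewards(-1)
    next_obs = eps.get_observations(-1)
    n = eps.get_extra_model_outputs("n_step", -1)
    done = eps.is_terminated or eps.is_truncated

# Old path (unchanged): a batch of B rows, each a sequence of length T.
batch = buffer.sample(batch_size_B=16, batch_length_T=64)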