From ec683376bde36a98988dfa149bc11087f1cf299b Mon Sep 17 00:00:00 2001
From: simonsays1980
Date: Mon, 11 Mar 2024 17:11:38 +0100
Subject: [PATCH] [RLlib] Added functionality to add `infos` and
 `extra_model_outputs` to the sample output of
 `PrioritizedEpisodeReplayBuffer`. (#43496)

---
 .../prioritized_episode_replay_buffer.py      | 345 ++++++++++++++++--
 .../test_prioritized_episode_replay_buffer.py | 135 ++++++-
 2 files changed, 456 insertions(+), 24 deletions(-)

diff --git a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
index 3730b10ec54b..6d4db4d136e6 100644
--- a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
+++ b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
@@ -6,6 +6,7 @@
 from numpy.typing import NDArray
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+from ray.rllib.core.columns import Columns
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree
 from ray.rllib.policy.sample_batch import SampleBatch
@@ -13,6 +14,7 @@
 from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.utils.spaces.space_utils import batch
 
 
 class PrioritizedEpisodeReplayBuffer(EpisodeReplayBuffer):
@@ -154,6 +156,9 @@ def __init__(
         self._free_nodes = deque(
             list(range(2 * tree_capacity)), maxlen=2 * tree_capacity
         )
+        # Keep track of the maximum index used in the trees. This helps
+        # to avoid traversing the complete trees.
+        self._max_idx = 0
 
         # Map from tree indices to sample indices (i.e. `self._indices`).
         self._tree_idx_to_sample_idx = {}
@@ -236,6 +241,8 @@ def add(
                 if idx_triple[0] in eps_evicted_idxs:
                     # Here we need the index of a sample in the segment tree.
                     self._free_nodes.appendleft(idx_triple[2])
+                    # If this sample held the maximum index, decrement it.
+                    self._max_idx -= 1 if self._max_idx == idx_triple[2] else 0
                     self._sum_segment[idx_triple[2]] = 0.0
                     self._min_segment[idx_triple[2]] = float("inf")
                 # Otherwise update the index in the index mapping.
@@ -293,10 +300,11 @@ def sample(
         *,
         batch_size_B: Optional[int] = None,
         batch_length_T: Optional[int] = None,
-        n_step: Optional[Union[int, Tuple]] = 1,
+        n_step: Optional[Union[int, Tuple]] = None,
         beta: float = 0.0,
         gamma: float = 0.99,
         include_infos: bool = False,
+        include_extra_model_outputs: bool = False,
     ) -> SampleBatchType:
         """Samples from a buffer in a prioritized way.
 
@@ -340,6 +348,13 @@ def sample(
             the batch. This could be of advantage, if the `info` contains
             values from the environment important for loss computation. If
             `True`, the info at the `"new_obs"` in the batch is included.
+            include_extra_model_outputs: A boolean indicating whether
+                `extra_model_outputs` should be included in the batch. This could be
+                of advantage, if the `extra_model_outputs` contain outputs from the
+                model that are important for loss computation and can only be computed
+                with the current state of the model (e.g. action log-probabilities).
+                If `True`, the extra model outputs at the `"obs"` in the batch are
+                included (the timestep at which the action is computed).
 
         Returns:
             A sample batch (observations, actions, rewards, new observations,
@@ -365,7 +380,7 @@ def sample(
             # Use random n-step sampling.
             random_n_step = True
         else:
-            actual_n_step = n_step
+            actual_n_step = n_step or 1
             random_n_step = False
 
         # Rows to return.
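A minimal numpy sketch of the proportional sampling and importance-weight math that `sample()` relies on. The priorities here are made up for illustration; the buffer itself stores `priority**alpha` in a `SumSegmentTree`/`MinSegmentTree` pair and replaces the linear `np.searchsorted` lookup below with the logarithmic `find_prefixsum_idx`:

import numpy as np

# Hypothetical per-timestep priorities (e.g. absolute TD-errors) -- made up.
priorities = np.array([0.5, 2.0, 0.1, 1.0, 3.0])
alpha, beta = 0.8, 0.7
rng = np.random.default_rng(seed=42)

# Priorities enter the trees as p_i = priority_i ** alpha.
p = priorities**alpha
total = p.sum()

# Proportional draw: u ~ Uniform(0, sum(p)), then find the first index whose
# prefix sum reaches u (the role of `SumSegmentTree.find_prefixsum_idx`).
u = rng.random() * total
idx = int(np.searchsorted(np.cumsum(p), u))

# Importance-sampling weight, normalized by the maximum possible weight so
# that all weights are <= 1 (same formula as in `sample()`).
num_timesteps = len(p)
p_sample = p[idx] / total
p_min = p.min() / total
max_weight = (p_min * num_timesteps) ** (-beta)
weight = (p_sample * num_timesteps) ** (-beta) / max_weight
print(f"sampled index={idx}, importance weight={weight:.3f}")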
@@ -373,14 +388,16 @@ def sample( next_observations = [[] for _ in range(batch_size_B)] actions = [[] for _ in range(batch_size_B)] rewards = [[] for _ in range(batch_size_B)] - is_terminated = [[False] for _ in range(batch_size_B)] - is_truncated = [[False] for _ in range(batch_size_B)] + is_terminated = [False for _ in range(batch_size_B)] + is_truncated = [False for _ in range(batch_size_B)] weights = [[] for _ in range(batch_size_B)] n_steps = [[] for _ in range(batch_size_B)] # If `info` should be included, construct also a container for them. - # TODO (simon): Add also `extra_model_outs`. if include_infos: infos = [[] for _ in range(batch_size_B)] + # If `extra_model_outputs` should be included, construct a container for them. + if include_extra_model_outputs: + extra_model_outputs = [[] for _ in range(batch_size_B)] # Keep track of the indices that were sampled last for updating the # weights later (see `ray.rllib.utils.replay_buffer.utils. # update_priorities_in_episode_replay_buffer`). @@ -429,7 +446,7 @@ def sample( if episode_ts - actual_n_step < 0: continue else: - n_steps[B].append(actual_n_step) + n_steps[B] = actual_n_step # Starting a new chunk. # Ensure that each row contains a tuple of the form: @@ -443,22 +460,33 @@ def sample( eps_rewards = episode.get_rewards( slice(episode_ts - actual_n_step, episode_ts) ) - observations[B].append(eps_observations[0]) - next_observations[B].append(eps_observations[-1]) + observations[B] = eps_observations[0] + next_observations[B] = eps_observations[-1] # Note, this will be the reward after executing action # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the sum of # all rewards that were collected over the last n steps. - rewards[B].append( - scipy.signal.lfilter([1], [1, -gamma], eps_rewards[::-1], axis=0)[-1] - ) - # rewards[B].append(sum(eps_rewards)) + rewards[B] = scipy.signal.lfilter( + [1], [1, -gamma], eps_rewards[::-1], axis=0 + )[-1] # Note, `SingleAgentEpisode` stores the action that followed # `o_t` with `o_(t+1)`, therefore, we need the next one. - actions[B].append(episode.get_actions(episode_ts - n_step)) + actions[B] = episode.get_actions(episode_ts - actual_n_step) if include_infos: # If infos are included we include the ones from the last timestep # as usually the info contains additional values about the last state. - infos[B].append(episode.get_infos(episode_ts)) + infos[B] = episode.get_infos(episode_ts) + if include_extra_model_outputs: + # If `extra_model_outputs` are included we include the ones from the + # first timestep as usually the `extra_model_outputs` contain additional + # values from the forward pass that produced the action at the first + # timestep. + # Note, we extract them into single row dictionaries similar to the + # infos, in a connector we can then extract these into single batch + # rows. + extra_model_outputs[B] = { + k: episode.get_extra_model_outputs(k, episode_ts - actual_n_step) + for k in episode.extra_model_outputs.keys() + } # If the sampled time step is the episode's last time step check, if # the episode is terminated or truncated. @@ -469,7 +497,7 @@ def sample( # TODO (simon): Check, if we have to correct here for sequences # later. actual_size = 1 - weights[B].append(weight / max_weight * actual_size) + weights[B] = weight / max_weight * actual_size # Increment counter. B += 1 @@ -481,21 +509,293 @@ def sample( # TODO Return SampleBatch instead of this simpler dict. 
        ret = {
-            SampleBatch.OBS: np.array(observations),
-            SampleBatch.ACTIONS: np.array(actions),
-            SampleBatch.REWARDS: np.array(rewards),
-            SampleBatch.NEXT_OBS: np.array(next_observations),
-            SampleBatch.TERMINATEDS: np.array(is_terminated),
-            SampleBatch.TRUNCATEDS: np.array(is_truncated),
+            # Note, observation and action spaces could be complex. `batch`
+            # takes care of these.
+            Columns.OBS: batch(observations),
+            Columns.ACTIONS: batch(actions),
+            Columns.REWARDS: np.array(rewards),
+            Columns.NEXT_OBS: batch(next_observations),
+            Columns.TERMINATEDS: np.array(is_terminated),
+            Columns.TRUNCATEDS: np.array(is_truncated),
             "weights": np.array(weights),
             "n_steps": np.array(n_steps),
         }
+        # Include infos if necessary.
+        if include_infos:
+            ret.update(
+                {
+                    SampleBatch.INFOS: infos,
+                }
+            )
+        # Include extra model outputs, if necessary.
+        if include_extra_model_outputs:
+            ret.update(
+                # These could be complex, too.
+                batch(extra_model_outputs)
+            )
+
+        return ret
+
+    # TODO (simon): Adjust docstring.
+    def sample_with_keys(
+        self,
+        num_items: Optional[int] = None,
+        *,
+        batch_size_B: Optional[int] = None,
+        batch_length_T: Optional[int] = None,
+        n_step: Optional[Union[int, Tuple]] = None,
+        beta: float = 0.0,
+        gamma: float = 0.99,
+        include_infos: bool = False,
+        include_extra_model_outputs: bool = False,
+    ) -> SampleBatchType:
+        """Samples from a buffer in a prioritized way.
+
+        This sampling method also adds (importance sampling) weights to
+        the returned batch. For details on prioritized sampling, see
+        Schaul et al. (2016).
+
+        Each sampled item defines a transition of the form:
+
+        `(o_t, a_t, sum(r_(t+1:t+n+1)), o_(t+n), terminated_(t+n), truncated_(t+n))`
+
+        where `o_(t+n)` is drawn by prioritized sampling, i.e. the priority
+        of `o_(t+n)` led to the sample and defines the importance weight that
+        is returned in the sample batch. `n` is defined by the `n_step` applied.
+
+        If requested, `info`s of a transition's last timestep `t+n` are added to
+        the batch.
+
+        Args:
+            num_items: Number of items (transitions) to sample from this
+                buffer.
+            batch_size_B: The number of rows (transitions) to return in the
+                batch.
+            n_step: The n-step to apply. For the default, the batch contains in
+                `"new_obs"` the observation and in `"obs"` the observation `n`
+                time steps before. The reward will be the sum of rewards
+                collected in between these two observations, and the action will
+                be the one executed n steps before, such that we always have the
+                state-action pair that triggered the rewards.
+                If `n_step` is a tuple, it is considered a range to sample
+                from. If `None`, we use `n_step=1`.
+            beta: The exponent of the importance sampling weight (see Schaul et
+                al. (2016)). A `beta=0.0` does not correct for the bias introduced
+                by prioritized replay, and `beta=1.0` fully corrects for it.
+            gamma: The discount factor to be used when applying n-step calculations.
+                The default of `0.99` should be replaced by the `Algorithm`'s
+                discount factor.
+            include_infos: A boolean indicating whether `info`s should be included in
+                the batch. This could be of advantage, if the `info` contains
+                values from the environment important for loss computation. If
+                `True`, the info at the `"new_obs"` in the batch is included.
+            include_extra_model_outputs: A boolean indicating whether
+                `extra_model_outputs` should be included in the batch. This could be
+                of advantage, if the `extra_model_outputs` contain outputs from the
+                model that are important for loss computation and can only be computed
+                with the current state of the model (e.g. action log-probabilities).
+                If `True`, the extra model outputs at the `"obs"` in the batch are
+                included (the timestep at which the action is computed).
+
+        Returns:
+            A sample batch (observations, actions, rewards, new observations,
+            terminateds, truncateds, weights) and, if requested, infos and extra model
+            outputs. Extra model outputs are extracted to single columns in the batch
+            and infos are kept as a list of dictionaries. Within each column, values
+            are keyed by episode id.
+        """
+        assert beta >= 0.0
+
+        if num_items is not None:
+            assert batch_size_B is None, (
+                "Cannot call `sample()` with both `num_items` and `batch_size_B` "
+                "provided! Use either one."
+            )
+            batch_size_B = num_items
+
+        # Use our default values if no sizes/lengths provided.
+        batch_size_B = batch_size_B or self.batch_size_B
+        batch_length_T = batch_length_T or self.batch_length_T
+
+        # Sample the n-step if necessary.
+        if isinstance(n_step, tuple):
+            # Use random n-step sampling.
+            random_n_step = True
+        else:
+            actual_n_step = n_step or 1
+            random_n_step = False
+
+        # Columns to return.
+        observations = {}
+        next_observations = {}
+        actions = {}
+        rewards = {}
+        is_terminated = {}
+        is_truncated = {}
+        weights = {}
+        n_steps = {}
+        # If `info` should be included, construct also a container for them.
+        if include_infos:
+            infos = {}
+        # If `extra_model_outputs` should be included, construct a container for them.
+        if include_extra_model_outputs:
+            # Get the keys from an episode in the buffer.
+            # TODO (simon, sven): What happens, if different episodes have different
+            # extra model outputs or some are missing?
+            extra_model_outputs = {
+                k: {} for k in self.episodes[0].extra_model_outputs.keys()
+            }
+        # Keep track of the indices that were sampled last for updating the
+        # weights later (see `ray.rllib.utils.replay_buffer.utils.
+        # update_priorities_in_episode_replay_buffer`).
+        self._last_sampled_indices = []
+
+        # Sample proportionally from replay buffer's segments using the weights.
+        total_segment_sum = self._sum_segment.sum()
+        p_min = self._min_segment.min() / total_segment_sum
+        max_weight = (p_min * self.get_num_timesteps()) ** (-beta)
+        B = 0
+        while B < batch_size_B:
+            # First, draw a random sample from Uniform(0, sum over all weights).
+            # Note, transitions with higher weight get sampled more often (as
+            # more random draws fall into larger intervals).
+            random_sum = self.rng.random() * self._sum_segment.sum(0, self._max_idx + 1)
+            # Get the highest index in the sum-tree for which the sum is
+            # smaller than or equal to the random sum sample.
+            # Note, we sample `o_(t + n_step)` as this is the state that
+            # brought the information contained in the TD-error (see Schaul
+            # et al. (2016), Algorithm 1).
+            idx = self._sum_segment.find_prefixsum_idx(random_sum)
+            # Get the theoretical probability mass for drawing this sample.
+            p_sample = self._sum_segment[idx] / total_segment_sum
+            # Compute the importance sampling weight.
+            weight = (p_sample * self.get_num_timesteps()) ** (-beta)
+            # Now, get the transition stored at this index.
+            index_triple = self._indices[self._tree_idx_to_sample_idx[idx]]
+
+            # Compute the actual episode index (offset by the number of
+            # already evicted episodes).
+            episode_idx, episode_ts = (
+                index_triple[0] - self._num_episodes_evicted,
+                index_triple[1],
+            )
+            episode = self.episodes[episode_idx]
+
+            # If we use random n-step sampling, draw the n-step for this item.
+            if random_n_step:
+                actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
+            # If we are at the end of an episode, continue.
+            # Note, priority sampling got us `o_(t+n)`, and for the loss
+            # calculation we additionally need `o_t`.
+            # TODO (simon): Maybe introduce a variable `num_retries` after which the
+            # while loop breaks when not enough samples could be collected
+            # to make n-step possible.
+            if episode_ts - actual_n_step < 0:
+                continue
+
+            # Starting a new chunk.
+            # Ensure that each row contains a tuple of the form:
+            #   (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
+            # TODO (simon): Implement version for sequence sampling when using RNNs.
+            eps_observations = episode.get_observations(
+                slice(episode_ts - actual_n_step, episode_ts + 1)
+            )
+            # Note, the reward that is collected by transitioning from `o_t` to
+            # `o_(t+1)` is stored in the next transition in `SingleAgentEpisode`.
+            eps_rewards = episode.get_rewards(
+                slice(episode_ts - actual_n_step, episode_ts)
+            )
+            if (episode.id_,) not in observations:
+                # Add the key to all containers.
+                observations[(episode.id_,)] = []
+                next_observations[(episode.id_,)] = []
+                actions[(episode.id_,)] = []
+                rewards[(episode.id_,)] = []
+                is_terminated[(episode.id_,)] = []
+                is_truncated[(episode.id_,)] = []
+                weights[(episode.id_,)] = []
+                n_steps[(episode.id_,)] = []
+                if include_infos:
+                    infos[(episode.id_,)] = []
+                if include_extra_model_outputs:
+                    # `extra_model_outputs` has the structure
+                    # `{"output_1": {(eps_id0,): [0.4, 2.3], ...}, ...}`
+                    for k in extra_model_outputs:
+                        extra_model_outputs[k][(episode.id_,)] = []
+
+            # Add the `n_step` used for this item.
+            n_steps[(episode.id_,)].append(actual_n_step)
+
+            observations[(episode.id_,)].append(eps_observations[0])
+            next_observations[(episode.id_,)].append(eps_observations[-1])
+            # Note, this will be the reward after executing action
+            # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the sum of
+            # all rewards that were collected over the last n steps.
+            rewards[(episode.id_,)].append(
+                scipy.signal.lfilter([1], [1, -gamma], eps_rewards[::-1], axis=0)[-1]
+            )
+            # Note, `SingleAgentEpisode` stores the action that followed
+            # `o_t` with `o_(t+1)`, therefore, we need the next one.
+            actions[(episode.id_,)].append(
+                episode.get_actions(episode_ts - actual_n_step)
+            )
+            if include_infos:
+                # If infos are included we include the ones from the last timestep
+                # as usually the info contains additional values about the last state.
+                infos[(episode.id_,)].append(episode.get_infos(episode_ts))
+            if include_extra_model_outputs:
+                # If `extra_model_outputs` are included we include the ones from the
+                # first timestep as usually the `extra_model_outputs` contain additional
+                # values from the forward pass that produced the action at the first
+                # timestep.
+                for k in extra_model_outputs:
+                    extra_model_outputs[k][(episode.id_,)].append(
+                        episode.get_extra_model_outputs(k, episode_ts - actual_n_step)
+                    )
+
+            # If the sampled time step is the episode's last time step, check if
+            # the episode is terminated or truncated.
+            if episode_ts == episode.t:
+                is_terminated[(episode.id_,)].append(episode.is_terminated)
+                is_truncated[(episode.id_,)].append(episode.is_truncated)
+            else:
+                is_terminated[(episode.id_,)].append(False)
+                is_truncated[(episode.id_,)].append(False)
+
+            # TODO (simon): Check, if we have to correct here for sequences
+            # later.
+            actual_size = 1
+            weights[(episode.id_,)].append(weight / max_weight * actual_size)
+
+            # Increment counter.
+ B += 1 + + # Keep track of sampled indices for updating priorities later. + self._last_sampled_indices.append(idx) + + self.sampled_timesteps += batch_size_B + + # TODO Return SampleBatch instead of this simpler dict. + ret = { + Columns.OBS: observations, + Columns.ACTIONS: actions, + Columns.REWARDS: rewards, + Columns.NEXT_OBS: next_observations, + Columns.TERMINATEDS: is_terminated, + Columns.TRUNCATEDS: is_truncated, + "weights": weights, + "n_steps": n_steps, + } + # Include infos if necessary. if include_infos: ret.update( { - SampleBatch.INFOS: np.array(infos), + Columns.INFOS: infos, } ) + # Include extra model outputs, if necessary. + if include_extra_model_outputs: + ret.update(extra_model_outputs) return ret @@ -583,6 +883,7 @@ def _get_free_node_and_assign(self, sample_index, weight: float = 1.0) -> int: """ # Get an index from the free nodes in the segment trees. idx = self._free_nodes.popleft() + self._max_idx = idx if idx > self._max_idx else self._max_idx # Add the weight to the segments. self._sum_segment[idx] = weight**self._alpha self._min_segment[idx] = weight**self._alpha diff --git a/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py index f515d399e02d..2f316ecd879a 100644 --- a/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_prioritized_episode_replay_buffer.py @@ -1,3 +1,4 @@ +import tree import unittest import numpy as np @@ -9,7 +10,7 @@ class TestPrioritizedEpisodeReplayBuffer(unittest.TestCase): @staticmethod - def _get_episode(episode_len=None, id_=None): + def _get_episode(episode_len=None, id_=None, with_extra_model_outs=False): eps = SingleAgentEpisode(id_=id_, observations=[0.0], infos=[{}]) ts = np.random.randint(1, 200) if episode_len is None else episode_len for t in range(ts): @@ -18,6 +19,9 @@ def _get_episode(episode_len=None, id_=None): action=int(t), reward=0.1 * (t + 1), infos={}, + extra_model_outputs={k: k for k in range(2)} + if with_extra_model_outs + else None, ) eps.is_terminated = np.random.random() > 0.5 eps.is_truncated = False if eps.is_terminated else np.random.random() > 0.8 @@ -196,7 +200,7 @@ def test_prioritized_buffer_sample_logic(self): # Now test a random n-step sampling. for _ in range(1000): - sample = buffer.sample(batch_size_B=16, n_step=None, beta=1.0) + sample = buffer.sample(batch_size_B=16, n_step=(1, 5), beta=1.0) ( obs, actions, @@ -240,6 +244,133 @@ def test_prioritized_buffer_sample_logic(self): # Ensure that there is variation in the n-steps. self.assertTrue(np.var(n_steps) > 0.0) + def test_infos_and_extra_model_outputs(self): + # Define replay buffer (alpha=0.8) + buffer = PrioritizedEpisodeReplayBuffer(capacity=10000, alpha=0.8) + + # Fill the buffer with episodes. + for _ in range(200): + episode = self._get_episode(with_extra_model_outs=True) + buffer.add(episode) + + # Now test a sampling with infos and extra model outputs (beta=0.7). + for _ in range(1000): + sample = buffer.sample( + batch_size_B=16, + n_step=1, + beta=0.7, + include_infos=True, + include_extra_model_outputs=True, + ) + ( + obs, + actions, + rewards, + next_obs, + is_terminated, + is_truncated, + weights, + n_steps, + infos, + # Note, each extra model output gets extracted + # to its own column. 
+                extra_model_outs_0,
+                extra_model_outs_1,
+            ) = (
+                sample["obs"],
+                sample["actions"],
+                sample["rewards"],
+                sample["new_obs"],
+                sample["terminateds"],
+                sample["truncateds"],
+                sample["weights"],
+                sample["n_steps"],
+                sample["infos"],
+                sample[0],
+                sample[1],
+            )
+
+            # Make sure terminated and truncated are never both True.
+            assert not np.any(np.logical_and(is_truncated, is_terminated))
+
+            # All fields have the same shape.
+            assert (
+                obs.shape
+                == rewards.shape
+                == actions.shape
+                == next_obs.shape
+                == is_truncated.shape
+                == is_terminated.shape
+                == weights.shape
+                == n_steps.shape
+                # Note, infos will be a list of dictionaries.
+                == (len(infos),)
+                == extra_model_outs_0.shape
+                == extra_model_outs_1.shape
+            )
+
+    def test_sample_with_keys(self):
+        # Define replay buffer (alpha=0.8).
+        buffer = PrioritizedEpisodeReplayBuffer(capacity=10000, alpha=0.8)
+
+        # Fill the buffer with episodes.
+        for _ in range(200):
+            episode = self._get_episode(with_extra_model_outs=True)
+            buffer.add(episode)
+
+        # Now test a sampling with infos and extra model outputs (beta=0.7).
+        for _ in range(1000):
+            sample = buffer.sample_with_keys(
+                batch_size_B=16,
+                n_step=1,
+                beta=0.7,
+                include_infos=True,
+                include_extra_model_outputs=True,
+            )
+
+            (
+                obs,
+                actions,
+                rewards,
+                next_obs,
+                is_terminated,
+                is_truncated,
+                weights,
+                n_steps,
+                infos,
+                # Note, each extra model output gets extracted
+                # to its own column.
+                extra_model_outs_0,
+                extra_model_outs_1,
+            ) = (
+                sample["obs"],
+                sample["actions"],
+                sample["rewards"],
+                sample["new_obs"],
+                sample["terminateds"],
+                sample["truncateds"],
+                sample["weights"],
+                sample["n_steps"],
+                sample["infos"],
+                sample[0],
+                sample[1],
+            )
+
+            # All fields have the same number of elements.
+            assert (
+                len(tree.flatten(obs))
+                == len(tree.flatten(rewards))
+                == len(tree.flatten(actions))
+                == len(tree.flatten(next_obs))
+                == len(tree.flatten(is_terminated))
+                == len(tree.flatten(is_truncated))
+                == len(tree.flatten(weights))
+                == len(tree.flatten(n_steps))
+                == sum([len(eps_infos) for eps_infos in infos.values()])
+                == len(tree.flatten(extra_model_outs_0))
+                == len(tree.flatten(extra_model_outs_1))
+            )
+
     def test_update_priorities(self):
         # Define replay buffer (alpha=1.0).
         buffer = PrioritizedEpisodeReplayBuffer(capacity=100)
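A short end-to-end usage sketch of the two new flags, mirroring the test helper above; the episode contents and the two extra-model-output keys (`0` and `1`) are made up for illustration:

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.replay_buffers.prioritized_episode_replay_buffer import (
    PrioritizedEpisodeReplayBuffer,
)

# Build one episode with two (hypothetical) extra model outputs per step,
# keyed 0 and 1, just like `_get_episode(with_extra_model_outs=True)`.
eps = SingleAgentEpisode(observations=[0.0], infos=[{}])
for t in range(100):
    eps.add_env_step(
        observation=float(t + 1),
        action=int(t),
        reward=0.1 * (t + 1),
        infos={},
        extra_model_outputs={k: k for k in range(2)},
    )
eps.is_terminated = True

buffer = PrioritizedEpisodeReplayBuffer(capacity=10000, alpha=0.8)
buffer.add(eps)

# Sample with both new flags enabled. Besides the usual columns ("obs",
# "actions", "rewards", "new_obs", "terminateds", "truncateds", "weights",
# "n_steps"), the batch now carries "infos" (a list of dicts) and one extra
# column per extra-model-output key (here 0 and 1).
sample = buffer.sample(
    batch_size_B=16,
    n_step=1,
    beta=0.7,
    include_infos=True,
    include_extra_model_outputs=True,
)
assert len(sample["infos"]) == len(sample[0]) == len(sample[1]) == 16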