diff --git a/rllib/BUILD b/rllib/BUILD
index 2f56aa2b08cd..50932e847781 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -906,6 +906,14 @@ py_test(
     srcs = ["algorithms/dreamer/tests/test_dreamer.py"]
 )
 
+# DT
+py_test(
+    name = "test_segmentation_buffer",
+    tags = ["team:rllib", "algorithms_dir"],
+    size = "medium",
+    srcs = ["algorithms/dt/tests/test_segmentation_buffer.py"]
+)
+
 # ES
 py_test(
     name = "test_es",
diff --git a/rllib/algorithms/dt/__init__.py b/rllib/algorithms/dt/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/rllib/algorithms/dt/segmentation_buffer.py b/rllib/algorithms/dt/segmentation_buffer.py
new file mode 100644
index 000000000000..d3ba9843cd51
--- /dev/null
+++ b/rllib/algorithms/dt/segmentation_buffer.py
@@ -0,0 +1,214 @@
+import logging
+from collections import defaultdict
+from typing import List
+import random
+
+import numpy as np
+
+from ray.rllib.evaluation.postprocessing import discount_cumsum
+from ray.rllib.policy.sample_batch import SampleBatch, concat_samples, MultiAgentBatch
+from ray.rllib.utils.typing import SampleBatchType
+
+logger = logging.getLogger(__name__)
+
+
+def front_pad_with_zero(arr: np.ndarray, max_seq_len: int):
+    """Pad arr on the front/left with 0 up to max_seq_len."""
+    length = arr.shape[0]
+    pad_length = max_seq_len - length
+    if pad_length > 0:
+        return np.concatenate(
+            [np.zeros((pad_length, *arr.shape[1:]), dtype=arr.dtype), arr], axis=0
+        )
+    else:
+        return arr
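A quick illustration of `front_pad_with_zero` (a minimal sketch, not part of the patch; the array values are arbitrary):

```python
import numpy as np
from ray.rllib.algorithms.dt.segmentation_buffer import front_pad_with_zero

# A 3-step observation sequence with obs_dim=2.
arr = np.ones((3, 2), dtype=np.float32)

padded = front_pad_with_zero(arr, max_seq_len=5)
assert padded.shape == (5, 2)        # Two zero rows were prepended.
assert np.allclose(padded[:2], 0.0)  # Padding sits at the front/left.
assert np.allclose(padded[2:], arr)  # The original data is untouched.

# Sequences already at max_seq_len are returned unchanged.
assert front_pad_with_zero(arr, max_seq_len=3) is arr
```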
+
+
+class SegmentationBuffer:
+    """A minimal replay buffer used by Decision Transformer (DT)
+    to process episodes into max_seq_len length segments and do shuffling.
+    """
+
+    def __init__(
+        self,
+        capacity: int = 20,
+        max_seq_len: int = 20,
+        max_ep_len: int = 1000,
+    ):
+        """
+        Args:
+            capacity: Maximum number of episodes the buffer can store.
+            max_seq_len: Length of segments that are sampled.
+            max_ep_len: Maximum length of episodes added.
+        """
+        self.capacity = capacity
+        self.max_seq_len = max_seq_len
+        self.max_ep_len = max_ep_len
+
+        self._buffer: List[SampleBatch] = []
+
+    def add(self, batch: SampleBatch):
+        """Add a SampleBatch of episodes. Replace a random episode if full.
+
+        Args:
+            batch: SampleBatch of full episodes.
+        """
+        episodes = batch.split_by_episode(key=SampleBatch.DONES)
+        for episode in episodes:
+            self._add_single_episode(episode)
+
+    def _add_single_episode(self, episode: SampleBatch):
+        # Error out if the episode is longer than max_ep_len.
+        # Note: the dataset can shuffle such that the same episode appears twice
+        # in a row; that is okay, since split_by_episode separates the copies.
+        ep_len = episode.env_steps()
+
+        if ep_len > self.max_ep_len:
+            raise ValueError(
+                f"The maximum rollout length is {self.max_ep_len} but we tried to add"
+                f" a rollout of {episode.env_steps()} steps to the SegmentationBuffer."
+            )
+
+        # Compute returns-to-go.
+        rewards = episode[SampleBatch.REWARDS].reshape(-1)
+        rtg = discount_cumsum(rewards, 1.0)
+        # rtg needs to be one longer than the rest for return targets during training.
+        rtg = np.concatenate([rtg, np.zeros((1,), dtype=np.float32)], axis=0)
+        episode[SampleBatch.RETURNS_TO_GO] = rtg[:, None]
+
+        # Add timesteps and attention masks.
+        episode[SampleBatch.T] = np.arange(ep_len, dtype=np.int32)
+        episode[SampleBatch.ATTENTION_MASKS] = np.ones(ep_len, dtype=np.float32)
+
+        # Add to the buffer.
+        if len(self._buffer) < self.capacity:
+            self._buffer.append(episode)
+        else:
+            # TODO: add config for sampling and eviction policies.
+            replace_ind = random.randint(0, self.capacity - 1)
+            self._buffer[replace_ind] = episode
+
+    def sample(self, batch_size: int) -> SampleBatch:
+        """Sample segments from the buffer.
+
+        Args:
+            batch_size: Number of segments to sample.
+
+        Returns:
+            SampleBatch of segments with keys and shape {
+                OBS: [batch_size, max_seq_len, obs_dim],
+                ACTIONS: [batch_size, max_seq_len, act_dim],
+                RETURNS_TO_GO: [batch_size, max_seq_len + 1, 1],
+                T: [batch_size, max_seq_len],
+                ATTENTION_MASKS: [batch_size, max_seq_len],
+            }
+        """
+        samples = [self._sample_single() for _ in range(batch_size)]
+        return concat_samples(samples)
+
+    def _sample_single(self) -> SampleBatch:
+        # TODO: sample proportionally to episode length.
+        # Sample a random episode from the buffer and then sample a random
+        # segment from that episode.
+        buffer_ind = random.randint(0, len(self._buffer) - 1)
+
+        episode = self._buffer[buffer_ind]
+        ep_len = episode[SampleBatch.OBS].shape[0]
+
+        # ei (end index) is exclusive.
+        ei = random.randint(1, ep_len)
+        # si (start index) is inclusive.
+        si = max(ei - self.max_seq_len, 0)
+
+        # Slice segments from obs, actions, timesteps, and rtgs.
+        obs = episode[SampleBatch.OBS][si:ei]
+        actions = episode[SampleBatch.ACTIONS][si:ei]
+        timesteps = episode[SampleBatch.T][si:ei]
+        masks = episode[SampleBatch.ATTENTION_MASKS][si:ei]
+        # Note that returns-to-go needs an extra elem as the rtg target for the last
+        # action token passed into the transformer.
+        returns_to_go = episode[SampleBatch.RETURNS_TO_GO][si : ei + 1]
+
+        # Front-pad if we're at the beginning of the episode and need more
+        # tokens to pass into the transformer, or if the episode is shorter
+        # than max_seq_len.
+        obs = front_pad_with_zero(obs, self.max_seq_len)
+        actions = front_pad_with_zero(actions, self.max_seq_len)
+        returns_to_go = front_pad_with_zero(returns_to_go, self.max_seq_len + 1)
+        timesteps = front_pad_with_zero(timesteps, self.max_seq_len)
+        masks = front_pad_with_zero(masks, self.max_seq_len)
+
+        assert obs.shape[0] == self.max_seq_len
+        assert actions.shape[0] == self.max_seq_len
+        assert timesteps.shape[0] == self.max_seq_len
+        assert masks.shape[0] == self.max_seq_len
+        assert returns_to_go.shape[0] == self.max_seq_len + 1
+
+        return SampleBatch(
+            {
+                SampleBatch.OBS: obs[None],
+                SampleBatch.ACTIONS: actions[None],
+                SampleBatch.RETURNS_TO_GO: returns_to_go[None],
+                SampleBatch.T: timesteps[None],
+                SampleBatch.ATTENTION_MASKS: masks[None],
+            }
+        )
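For orientation, a minimal usage sketch of `SegmentationBuffer` (not part of the patch; the toy episode only needs `OBS`, `ACTIONS`, `REWARDS`, and `DONES`, since the buffer derives the remaining columns itself):

```python
import numpy as np
from ray.rllib.algorithms.dt.segmentation_buffer import SegmentationBuffer
from ray.rllib.policy.sample_batch import SampleBatch

ep_len, obs_dim, act_dim = 7, 4, 2
episode = SampleBatch(
    {
        SampleBatch.OBS: np.random.randn(ep_len, obs_dim).astype(np.float32),
        SampleBatch.ACTIONS: np.random.randn(ep_len, act_dim).astype(np.float32),
        SampleBatch.REWARDS: np.ones(ep_len, dtype=np.float32),
        SampleBatch.DONES: np.array([False] * (ep_len - 1) + [True]),
    }
)

buffer = SegmentationBuffer(capacity=4, max_seq_len=5, max_ep_len=10)
buffer.add(episode)

batch = buffer.sample(batch_size=8)
# Each sampled segment is front-padded to exactly max_seq_len steps;
# returns-to-go carries one extra element for the final return target.
assert batch[SampleBatch.OBS].shape == (8, 5, obs_dim)
assert batch[SampleBatch.RETURNS_TO_GO].shape == (8, 6, 1)
```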
+
+
+class MultiAgentSegmentationBuffer:
+    """A minimal replay buffer used by Decision Transformer (DT)
+    to process episodes into max_seq_len length segments and do shuffling.
+    Stores one SegmentationBuffer per policy id of incoming MultiAgentBatches.
+    """
+
+    def __init__(
+        self,
+        capacity: int = 20,
+        max_seq_len: int = 20,
+        max_ep_len: int = 1000,
+    ):
+        """
+        Args:
+            capacity: Maximum number of episodes each per-policy buffer can store.
+            max_seq_len: Length of segments that are sampled.
+            max_ep_len: Maximum length of episodes added.
+        """
+
+        def new_buffer():
+            return SegmentationBuffer(capacity, max_seq_len, max_ep_len)
+
+        self.buffers = defaultdict(new_buffer)
+
+    def add(self, batch: SampleBatchType):
+        """Add a MultiAgentBatch of episodes. Replace random episodes if full.
+
+        Args:
+            batch: MultiAgentBatch of full episodes.
+        """
+        # Make a copy so the replay buffer doesn't pin plasma memory.
+        batch = batch.copy()
+        # Handle everything as if multi-agent.
+        batch = batch.as_multi_agent()
+
+        for policy_id, sample_batch in batch.policy_batches.items():
+            self.buffers[policy_id].add(sample_batch)
+
+    def sample(self, batch_size: int) -> MultiAgentBatch:
+        """Sample segments from the buffer.
+
+        Args:
+            batch_size: Number of segments to sample per policy.
+
+        Returns:
+            MultiAgentBatch of segments with keys and shape {
+                OBS: [batch_size, max_seq_len, obs_dim],
+                ACTIONS: [batch_size, max_seq_len, act_dim],
+                RETURNS_TO_GO: [batch_size, max_seq_len + 1, 1],
+                T: [batch_size, max_seq_len],
+                ATTENTION_MASKS: [batch_size, max_seq_len],
+            }
+        """
+        samples = {}
+        for policy_id, buffer in self.buffers.items():
+            samples[policy_id] = buffer.sample(batch_size)
+        return MultiAgentBatch(samples, batch_size)
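The multi-agent wrapper can be exercised the same way (again a hypothetical sketch; `toy_episode` and the policy ids are made up for illustration):

```python
import numpy as np
from ray.rllib.algorithms.dt.segmentation_buffer import MultiAgentSegmentationBuffer
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch


def toy_episode(ep_len=6, obs_dim=3, act_dim=2):
    # Arbitrary values; only the shapes and the final done flag matter here.
    return SampleBatch(
        {
            SampleBatch.OBS: np.zeros((ep_len, obs_dim), np.float32),
            SampleBatch.ACTIONS: np.zeros((ep_len, act_dim), np.float32),
            SampleBatch.REWARDS: np.ones(ep_len, np.float32),
            SampleBatch.DONES: np.array([False] * (ep_len - 1) + [True]),
        }
    )


ma_buffer = MultiAgentSegmentationBuffer(capacity=4, max_seq_len=5, max_ep_len=10)
ma_buffer.add(MultiAgentBatch({"pol1": toy_episode(), "pol2": toy_episode()}, 12))

# Each policy id gets its own underlying SegmentationBuffer.
ma_sample = ma_buffer.sample(8)
assert set(ma_sample.policy_batches) == {"pol1", "pol2"}
assert ma_sample.policy_batches["pol1"][SampleBatch.OBS].shape == (8, 5, 3)
```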
diff --git a/rllib/algorithms/dt/tests/test_segmentation_buffer.py b/rllib/algorithms/dt/tests/test_segmentation_buffer.py
new file mode 100644
index 000000000000..b2cf83b46dce
--- /dev/null
+++ b/rllib/algorithms/dt/tests/test_segmentation_buffer.py
@@ -0,0 +1,419 @@
+import unittest
+from typing import Union, List
+
+import numpy as np
+
+import ray
+from ray.rllib.algorithms.dt.segmentation_buffer import (
+    SegmentationBuffer,
+    MultiAgentSegmentationBuffer,
+)
+from ray.rllib.policy.sample_batch import (
+    SampleBatch,
+    MultiAgentBatch,
+    concat_samples,
+    DEFAULT_POLICY_ID,
+)
+from ray.rllib.utils import test_utils
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.typing import PolicyID
+
+tf1, tf, tfv = try_import_tf()
+torch, _ = try_import_torch()
+
+
+def _generate_episode_batch(ep_len, eps_id, obs_dim=8, act_dim=3):
+    """Generate a batch containing one episode."""
+    # The values are not semantically meaningful, but filling the fields with
+    # eps_id lets the tests identify which episode a sample came from.
+    batch = SampleBatch(
+        {
+            SampleBatch.OBS: np.full((ep_len, obs_dim), eps_id, dtype=np.float32),
+            SampleBatch.ACTIONS: np.full(
+                (ep_len, act_dim), eps_id + 100, dtype=np.float32
+            ),
+            SampleBatch.REWARDS: np.ones((ep_len,), dtype=np.float32),
+            SampleBatch.RETURNS_TO_GO: np.arange(
+                ep_len, -1, -1, dtype=np.float32
+            ).reshape((ep_len + 1, 1)),
+            SampleBatch.EPS_ID: np.full((ep_len,), eps_id, dtype=np.int32),
+            SampleBatch.T: np.arange(ep_len, dtype=np.int32),
+            SampleBatch.ATTENTION_MASKS: np.ones(ep_len, dtype=np.float32),
+            SampleBatch.DONES: np.array([False] * (ep_len - 1) + [True]),
+        }
+    )
+    return batch
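Note that the hand-written `RETURNS_TO_GO` column above is consistent with what `_add_single_episode` computes for unit rewards: `discount_cumsum` with gamma 1.0 plus one trailing zero. A minimal sketch:

```python
import numpy as np
from ray.rllib.evaluation.postprocessing import discount_cumsum

ep_len = 4
rewards = np.ones(ep_len, dtype=np.float32)

# With gamma=1.0, the return-to-go at step t is the number of remaining steps.
rtg = discount_cumsum(rewards, 1.0)
# The buffer appends one trailing zero as the return target for the last action.
rtg = np.concatenate([rtg, np.zeros(1, dtype=np.float32)])

# Identical to the hand-written column in _generate_episode_batch.
assert np.allclose(rtg, np.arange(ep_len, -1, -1, dtype=np.float32))
```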
+
+
+def _assert_sample_batch_keys(batch: SampleBatch):
+    """Assert that the sampled batch has the requisite keys."""
+    assert SampleBatch.OBS in batch
+    assert SampleBatch.ACTIONS in batch
+    assert SampleBatch.RETURNS_TO_GO in batch
+    assert SampleBatch.T in batch
+    assert SampleBatch.ATTENTION_MASKS in batch
+
+
+def _assert_sample_batch_not_equal(b1: SampleBatch, b2: SampleBatch):
+    """Assert that the two batches are not equal."""
+    for key in b1.keys() & b2.keys():
+        if b1[key].shape == b2[key].shape:
+            assert not np.allclose(
+                b1[key], b2[key]
+            ), f"Key {key} contains the same values when it should not."
+
+
+def _assert_is_segment(segment: SampleBatch, episode: SampleBatch):
+    """Assert that the sampled segment is a segment of episode."""
+    timesteps = segment[SampleBatch.T]
+    masks = segment[SampleBatch.ATTENTION_MASKS] > 0.5
+    seq_len = timesteps.shape[0]
+    episode_segment = episode.slice(timesteps[0], timesteps[-1] + 1)
+    assert np.allclose(
+        segment[SampleBatch.OBS][masks], episode_segment[SampleBatch.OBS]
+    )
+    assert np.allclose(
+        segment[SampleBatch.ACTIONS][masks], episode_segment[SampleBatch.ACTIONS]
+    )
+    assert np.allclose(
+        segment[SampleBatch.RETURNS_TO_GO][:seq_len][masks],
+        episode_segment[SampleBatch.RETURNS_TO_GO],
+    )
+
+
+def _get_internal_buffer(
+    buffer: Union[SegmentationBuffer, MultiAgentSegmentationBuffer],
+    policy_id: PolicyID = DEFAULT_POLICY_ID,
+) -> List[SampleBatch]:
+    """Get the internal buffer list from the buffer. If MultiAgent then return the
+    internal buffer corresponding to the given policy_id.
+    """
+    if type(buffer) == SegmentationBuffer:
+        return buffer._buffer
+    elif type(buffer) == MultiAgentSegmentationBuffer:
+        return buffer.buffers[policy_id]._buffer
+    else:
+        raise NotImplementedError
+
+
+def _as_sample_batch(
+    batch: Union[SampleBatch, MultiAgentBatch],
+    policy_id: PolicyID = DEFAULT_POLICY_ID,
+) -> SampleBatch:
+    """Returns a SampleBatch. If MultiAgentBatch then return the SampleBatch
+    corresponding to the given policy_id.
+    """
+    if type(batch) == SampleBatch:
+        return batch
+    elif type(batch) == MultiAgentBatch:
+        return batch.policy_batches[policy_id]
+    else:
+        raise NotImplementedError
+
+
+class TestSegmentationBuffer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
+    def test_add(self):
+        """Test adding to the segmentation buffer."""
+        for buffer_cls in [SegmentationBuffer, MultiAgentSegmentationBuffer]:
+            max_seq_len = 3
+            max_ep_len = 10
+            capacity = 1
+            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
+
+            # Generate a batch of four episodes.
+            episode_batches = []
+            for i in range(4):
+                episode_batches.append(_generate_episode_batch(max_ep_len, i))
+            batch = concat_samples(episode_batches)
+
+            # Add to the buffer and check that only the last one is kept
+            # (due to replacement).
+            buffer.add(batch)
+
+            self.assertEqual(
+                len(_get_internal_buffer(buffer)),
+                1,
+                "The internal buffer should only contain one SampleBatch since"
+                " the capacity is 1.",
+            )
+            test_utils.check(episode_batches[-1], _get_internal_buffer(buffer)[0])
+
+            # Add again.
+            buffer.add(episode_batches[0])
+
+            test_utils.check(episode_batches[0], _get_internal_buffer(buffer)[0])
+
+            # Make a buffer with enough capacity.
+            capacity = len(episode_batches)
+            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
+
+            # Add to the buffer and make sure all episodes are in.
+            buffer.add(batch)
+            self.assertEqual(
+                len(_get_internal_buffer(buffer)),
+                len(episode_batches),
+                "internal buffer doesn't have the right number of episodes.",
+            )
+            for i in range(len(episode_batches)):
+                test_utils.check(episode_batches[i], _get_internal_buffer(buffer)[i])
+
+            # Add another episode and make sure it replaced one of them.
+            new_batch = _generate_episode_batch(max_ep_len, 12345)
+            buffer.add(new_batch)
+            self.assertEqual(
+                len(_get_internal_buffer(buffer)),
+                len(episode_batches),
+                "internal buffer doesn't have the right number of episodes.",
+            )
+            found = False
+            for episode_batch in _get_internal_buffer(buffer):
+                if episode_batch[SampleBatch.EPS_ID][0] == 12345:
+                    test_utils.check(episode_batch, new_batch)
+                    found = True
+                    break
+            assert found, "new_batch not added to buffer."
+
+            # Test that adding a too-long episode raises an error.
+            long_batch = _generate_episode_batch(max_ep_len + 1, 123)
+            with self.assertRaises(ValueError):
+                buffer.add(long_batch)
+
+    def test_sample_basic(self):
+        """Test sampling from a segmentation buffer."""
+        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
+            max_seq_len = 5
+            max_ep_len = 15
+            capacity = 4
+            obs_dim = 10
+            act_dim = 2
+
+            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
+
+            # Generate a batch and add it to the buffer.
+            episode_batches = []
+            for i in range(8):
+                episode_batches.append(
+                    _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
+                )
+            batch = concat_samples(episode_batches)
+            buffer.add(batch)
+
+            # Sample a few times and check the shapes.
+            for bs in range(10, 20):
+                batch = _as_sample_batch(buffer.sample(bs))
+                # Check that the keys exist.
+                _assert_sample_batch_keys(batch)
+
+                # Check the shapes.
+                self.assertEqual(
+                    batch[SampleBatch.OBS].shape, (bs, max_seq_len, obs_dim)
+                )
+                self.assertEqual(
+                    batch[SampleBatch.ACTIONS].shape, (bs, max_seq_len, act_dim)
+                )
+                self.assertEqual(
+                    batch[SampleBatch.RETURNS_TO_GO].shape, (bs, max_seq_len + 1, 1)
+                )
+                self.assertEqual(batch[SampleBatch.T].shape, (bs, max_seq_len))
+                self.assertEqual(
+                    batch[SampleBatch.ATTENTION_MASKS].shape, (bs, max_seq_len)
+                )
+
+    def test_sample_content(self):
+        """Test that the contents of the samples are valid."""
+        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
+            max_seq_len = 5
+            max_ep_len = 200
+            capacity = 1
+            obs_dim = 11
+            act_dim = 1
+
+            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
+
+            # Generate a single episode and add it to the buffer.
+            episode = _generate_episode_batch(max_ep_len, 123, obs_dim, act_dim)
+            buffer.add(episode)
+
+            # Sample twice and make sure the samples are not equal.
+            # With a 200 max_ep_len and 200 samples, the probability that the two
+            # samples are equal by chance is (1/200)**200, which is basically zero.
+            sample1 = _as_sample_batch(buffer.sample(200))
+            sample2 = _as_sample_batch(buffer.sample(200))
+            _assert_sample_batch_keys(sample1)
+            _assert_sample_batch_keys(sample2)
+            _assert_sample_batch_not_equal(sample1, sample2)
+
+            # Sample and make sure the segments are actual segments of the episode.
+            batch = _as_sample_batch(buffer.sample(1000))
+            _assert_sample_batch_keys(batch)
+            for elem in batch.rows():
+                _assert_is_segment(SampleBatch(elem), episode)
+
+    def test_sample_capacity(self):
+        """Test that sampling from a buffer of capacity > 1 works."""
+        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
+            max_seq_len = 3
+            max_ep_len = 10
+            capacity = 100
+            obs_dim = 1
+            act_dim = 1
+
+            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
+
+            # Generate a batch and add it to the buffer.
+            episode_batches = []
+            for i in range(capacity):
+                episode_batches.append(
+                    _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
+                )
+            buffer.add(concat_samples(episode_batches))
+
+            # Sample 100 times and check that the samples come from at least 2
+            # different episodes. The probability of all samples coming from a
+            # single episode by chance is (1/100)**99, which is basically zero.
+            batch = _as_sample_batch(buffer.sample(100))
+            eps_ids = set()
+            for i in range(100):
+                # The obs generated by _generate_episode_batch contain the eps_id.
+                # Use index -1 because there might be front padding.
+                eps_id = int(batch[SampleBatch.OBS][i, -1, 0])
+                eps_ids.add(eps_id)
+
+            self.assertGreater(
+                len(eps_ids), 1, "buffer.sample is always returning the same episode."
+            )
+
+    def test_padding(self):
+        """Test that sample will front-pad segments."""
+        for buffer_cls in (SegmentationBuffer, MultiAgentSegmentationBuffer):
+            max_seq_len = 10
+            max_ep_len = 100
+            capacity = 1
+            obs_dim = 3
+            act_dim = 2
+
+            buffer = buffer_cls(capacity, max_seq_len, max_ep_len)
+
+            for ep_len in range(1, max_seq_len):
+                # Generate batches with episode lengths that are shorter than
+                # max_seq_len to test padding.
+                batch = _generate_episode_batch(ep_len, 123, obs_dim, act_dim)
+                buffer.add(batch)
+
+                samples = _as_sample_batch(buffer.sample(50))
+                for i in range(50):
+                    # Calculate the number of pads based on the attention mask.
+                    num_pad = int(
+                        ep_len - samples[SampleBatch.ATTENTION_MASKS][i].sum()
+                    )
+                    for key in samples.keys():
+                        # Make sure padding zeros were added.
+                        assert np.allclose(
+                            samples[key][i, :num_pad], 0.0
+                        ), "samples were not padded correctly."
+
+    def test_multi_agent(self):
+        max_seq_len = 5
+        max_ep_len = 20
+        capacity = 10
+        obs_dim = 3
+        act_dim = 5
+
+        ma_buffer = MultiAgentSegmentationBuffer(capacity, max_seq_len, max_ep_len)
+
+        policy_id1 = "1"
+        policy_id2 = "2"
+        policy_id3 = "3"
+        policy_ids = {policy_id1, policy_id2, policy_id3}
+
+        policy1_batches = []
+        for i in range(0, 10):
+            policy1_batches.append(
+                _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
+            )
+        policy2_batches = []
+        for i in range(10, 20):
+            policy2_batches.append(
+                _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
+            )
+        policy3_batches = []
+        for i in range(20, 30):
+            policy3_batches.append(
+                _generate_episode_batch(max_ep_len, i, obs_dim, act_dim)
+            )
+
+        batches_mapping = {
+            policy_id1: policy1_batches,
+            policy_id2: policy2_batches,
+            policy_id3: policy3_batches,
+        }
+
+        ma_batch = MultiAgentBatch(
+            {
+                policy_id1: concat_samples(policy1_batches),
+                policy_id2: concat_samples(policy2_batches),
+                policy_id3: concat_samples(policy3_batches),
+            },
+            max_ep_len * 10,
+        )
+
+        ma_buffer.add(ma_batch)
+
+        # Check that all episodes were added properly.
+        for policy_id in policy_ids:
+            assert policy_id in ma_buffer.buffers.keys()
+
+        for policy_id, buffer in ma_buffer.buffers.items():
+            assert policy_id in policy_ids
+            for i in range(10):
+                test_utils.check(
+                    batches_mapping[policy_id][i], _get_internal_buffer(buffer)[i]
+                )
+
+        # Check that sampling works properly.
+        for _ in range(50):
+            ma_sample = ma_buffer.sample(100)
+            for policy_id in policy_ids:
+                assert policy_id in ma_sample.policy_batches.keys()
+
+            for policy_id, batch in ma_sample.policy_batches.items():
+                eps_id_start = (int(policy_id) - 1) * 10
+                eps_id_end = eps_id_start + 10
+
+                _assert_sample_batch_keys(batch)
+
+                for i in range(100):
+                    # The obs generated by _generate_episode_batch contain the
+                    # eps_id. Use index -1 because there might be front padding.
+                    eps_id = int(batch[SampleBatch.OBS][i, -1, 0])
+                    assert (
+                        eps_id_start <= eps_id < eps_id_end
+                    ), "batch within multi-agent batch has the wrong agent's episode."
+
+        # Sample twice and make sure the samples are not equal (the probability
+        # of them being equal by chance is basically zero).
+        ma_sample1 = ma_buffer.sample(200)
+        ma_sample2 = ma_buffer.sample(200)
+        for policy_id in policy_ids:
+            _assert_sample_batch_not_equal(
+                ma_sample1.policy_batches[policy_id],
+                ma_sample2.policy_batches[policy_id],
+            )
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py
index d065cffa72f1..4aca5ea9b53f 100644
--- a/rllib/policy/sample_batch.py
+++ b/rllib/policy/sample_batch.py
@@ -49,6 +49,10 @@ class SampleBatch(dict):
     OBS_EMBEDS = "obs_embeds"
     T = "t"
 
+    # Decision Transformer (DT) keys.
+    RETURNS_TO_GO = "returns_to_go"
+    ATTENTION_MASKS = "attention_masks"
+
     # Extra action fetches keys.
     ACTION_DIST_INPUTS = "action_dist_inputs"
     ACTION_PROB = "action_prob"
@@ -335,10 +339,13 @@ def shuffle(self) -> "SampleBatch":
         return self
 
     @PublicAPI
-    def split_by_episode(self) -> List["SampleBatch"]:
+    def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
         """Splits by `eps_id` column and returns list of new batches.
         If `eps_id` is not present, splits by `dones` instead.
 
+        Args:
+            key: If specified, overrides the default and splits by this key.
+
         Returns:
             List of batches, one per distinct episode.
 
@@ -370,8 +377,8 @@ def split_by_episode(self) -> List["SampleBatch"]:
             [{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
 
         """
-        slices = []
-        if SampleBatch.EPS_ID in self:
+        def slice_by_eps_id():
+            slices = []
             # Produce a new slice whenever we find a new episode ID.
             cur_eps_id = self[SampleBatch.EPS_ID][0]
             offset = 0
@@ -383,9 +390,10 @@ def split_by_episode(self) -> List["SampleBatch"]:
                     cur_eps_id = next_eps_id
             # Add final slice.
             slices.append(self[offset : self.count])
+            return slices
 
-        # No eps_id in data -> split by dones instead
-        elif SampleBatch.DONES in self:
+        def slice_by_dones():
+            slices = []
             offset = 0
             for i in range(self.count):
                 if self[SampleBatch.DONES][i]:
@@ -397,8 +405,30 @@ def split_by_episode(self) -> List["SampleBatch"]:
             # Add final slice.
             if offset != self.count:
                 slices.append(self[offset:])
+            return slices
+
+        key_to_method = {
+            SampleBatch.EPS_ID: slice_by_eps_id,
+            SampleBatch.DONES: slice_by_dones,
+        }
+
+        # If key is not specified, default to this resolution order.
+        key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]
+
+        slices = None
+        if key is not None:
+            # If key is specified, use it directly.
+            if key not in self:
+                raise KeyError(f"{self} does not have key `{key}`!")
+            slices = key_to_method[key]()
         else:
-            raise KeyError(f"{self} does not have `eps_id` or `dones`!")
+            # If key is not specified, go through the keys in order.
+            for key in key_resolve_order:
+                if key in self:
+                    slices = key_to_method[key]()
+                    break
+            if slices is None:
+                raise KeyError(f"{self} does not have keys {key_resolve_order}!")
 
         assert (
             sum(s.count for s in slices) == self.count
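A minimal sketch of the new `key` argument in action (the data values are arbitrary):

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch(
    {
        "a": np.array([1, 2, 3, 4, 5]),
        SampleBatch.EPS_ID: np.array([0, 0, 0, 1, 1]),
        SampleBatch.DONES: np.array([False, False, True, False, True]),
    }
)

# Default behavior: eps_id takes precedence over dones.
assert len(batch.split_by_episode()) == 2

# Explicitly split on the dones column instead, as the SegmentationBuffer
# does (its offline batches may lack reliable eps_ids).
assert len(batch.split_by_episode(key=SampleBatch.DONES)) == 2
```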
diff --git a/rllib/policy/tests/test_sample_batch.py b/rllib/policy/tests/test_sample_batch.py
index e4be363b0cbe..23909abdc27f 100644
--- a/rllib/policy/tests/test_sample_batch.py
+++ b/rllib/policy/tests/test_sample_batch.py
@@ -282,6 +282,14 @@ def test_split_by_episode(self):
         eps_split = [b["a"] for b in s.split_by_episode()]
         check(true_split, eps_split)
 
+        # Check that splitting by EPS_ID works correctly when explicitly specified.
+        eps_split = [b["a"] for b in s.split_by_episode(key="eps_id")]
+        check(true_split, eps_split)
+
+        # Check that splitting by DONES works correctly when explicitly specified.
+        eps_split = [b["a"] for b in s.split_by_episode(key="dones")]
+        check(true_split, eps_split)
+
         # Check that splitting by DONES works correctly
         del s["eps_id"]
         dones_split = [b["a"] for b in s.split_by_episode()]
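For completeness, a minimal sketch of the two KeyError paths in the refactored `split_by_episode` (the batch below is arbitrary):

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({"a": np.array([1, 2, 3])})

# Neither eps_id nor dones is present -> KeyError.
try:
    batch.split_by_episode()
except KeyError as e:
    print(e)

# An explicitly requested key that is missing -> KeyError as well.
try:
    batch.split_by_episode(key=SampleBatch.DONES)
except KeyError as e:
    print(e)
```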