From 882510bb6e77c3e1687003fcfbaf6986677bb3d7 Mon Sep 17 00:00:00 2001 From: simonsays1980 Date: Fri, 6 Sep 2024 18:11:21 +0200 Subject: [PATCH] [RLlib; Off-policy] Add episode sampling to `EpisodeReplayBuffer`. (#47500) Signed-off-by: ujjawal-khare --- rllib/utils/replay_buffers/episode_replay_buffer.py | 7 ------- rllib/utils/replay_buffers/prioritized_episode_buffer.py | 1 - 2 files changed, 8 deletions(-) diff --git a/rllib/utils/replay_buffers/episode_replay_buffer.py b/rllib/utils/replay_buffers/episode_replay_buffer.py index 4a3309b301d9..90ae1fa35664 100644 --- a/rllib/utils/replay_buffers/episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/episode_replay_buffer.py @@ -218,7 +218,6 @@ def sample( include_infos: bool = False, include_extra_model_outputs: bool = False, sample_episodes: Optional[bool] = False, - finalize: bool = False, **kwargs, ) -> Union[SampleBatchType, SingleAgentEpisode]: """Samples from a buffer in a randomized way. @@ -263,7 +262,6 @@ def sample( actual state of model e.g. action log-probabilities, etc.). If `True`, the extra model outputs at the `"obs"` in the batch is included (the timestep at which the action is computed). - finalize: If episodes should be finalized. Returns: Either a batch with transitions in each row or (if `return_episodes=True`) @@ -281,7 +279,6 @@ def sample( gamma=gamma, include_infos=include_infos, include_extra_model_outputs=include_extra_model_outputs, - finalize=finalize, ) else: return self._sample_batch( @@ -427,7 +424,6 @@ def _sample_episodes( gamma: float = 0.99, include_infos: bool = False, include_extra_model_outputs: bool = False, - finalize: bool = False, **kwargs, ) -> List[SingleAgentEpisode]: """Samples episodes from a buffer in a randomized way. @@ -472,7 +468,6 @@ def _sample_episodes( actual state of model e.g. action log-probabilities, etc.). If `True`, the extra model outputs at the `"obs"` in the batch is included (the timestep at which the action is computed). - finalize: If episodes should be finalized. Returns: A list of 1-step long episodes containing all basic episode data and if @@ -583,8 +578,6 @@ def _sample_episodes( len_lookback_buffer=0, t_started=episode_ts, ) - if finalize: - sampled_episode.finalize() sampled_episodes.append(sampled_episode) # Increment counter. diff --git a/rllib/utils/replay_buffers/prioritized_episode_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_buffer.py index 8c5160068beb..8ce1d2770793 100644 --- a/rllib/utils/replay_buffers/prioritized_episode_buffer.py +++ b/rllib/utils/replay_buffers/prioritized_episode_buffer.py @@ -310,7 +310,6 @@ def sample( gamma: float = 0.99, include_infos: bool = False, include_extra_model_outputs: bool = False, - finalize: bool = False, **kwargs, ) -> SampleBatchType: """Samples from a buffer in a prioritized way.