Skip to content

Commit

Permalink
[RLlib] Moving sampling coordination for `batch_mode=complete_episodes` to `synchronous_parallel_sample`. (#46321)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonsays1980 authored Jul 4, 2024
1 parent 2432b62 commit 3bdcab6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
3 changes: 0 additions & 3 deletions rllib/algorithms/tests/test_callbacks_on_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,6 @@ def test_episode_and_sample_callbacks_batch_mode_complete_episodes(self):

# Train one iteration.
algo.train()
# We must have had exactly one `sample()` call on our EnvRunner.
if not multi_agent:
self.assertEqual(callback_obj.counts["sample"], 1)
# We should have had at least one episode start.
self.assertGreater(callback_obj.counts["start"], 0)
# Episode starts must be exactly the same as episode ends (b/c we always complete
Expand Down
22 changes: 10 additions & 12 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,17 @@ def sample(
explore=explore,
random_actions=random_actions,
)
# For complete episodes mode, sample as long as the number of timesteps
# done is smaller than the `train_batch_size`.
# For complete episodes mode, sample a single episode and
# leave coordination of sampling to `synchronous_parallel_sample`.
# TODO (simon, sven): The coordination will eventually move
# to `EnvRunnerGroup`, so that an algorithm would simply
# call `EnvRunnerGroup.sample()`.
else:
total = 0
samples = []
while total < self.config.train_batch_size:
episodes = self._sample_episodes(
num_episodes=self.num_envs,
explore=explore,
random_actions=random_actions,
)
total += sum(len(e) for e in episodes)
samples.extend(episodes)
samples = self._sample_episodes(
num_episodes=1,
explore=explore,
random_actions=random_actions,
)

# Make the `on_sample_end` callback.
self._callbacks.on_sample_end(
Expand Down

0 comments on commit 3bdcab6

Please sign in to comment.