diff --git a/rllib/algorithms/tests/test_callbacks_on_env_runner.py b/rllib/algorithms/tests/test_callbacks_on_env_runner.py
index 898717c6c2b2..6afa874509e0 100644
--- a/rllib/algorithms/tests/test_callbacks_on_env_runner.py
+++ b/rllib/algorithms/tests/test_callbacks_on_env_runner.py
@@ -179,9 +179,6 @@ def test_episode_and_sample_callbacks_batch_mode_complete_episodes(self):
             # Train one iteration.
             algo.train()

-            # We must have had exactly one `sample()` call on our EnvRunner.
-            if not multi_agent:
-                self.assertEqual(callback_obj.counts["sample"], 1)
             # We should have had at least one episode start.
             self.assertGreater(callback_obj.counts["start"], 0)
             # Episode starts must be exact same as episode ends (b/c we always complete
diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py
index f9b20b498291..78e36143ea38 100644
--- a/rllib/env/single_agent_env_runner.py
+++ b/rllib/env/single_agent_env_runner.py
@@ -197,19 +197,17 @@ def sample(
                 explore=explore,
                 random_actions=random_actions,
             )
-        # For complete episodes mode, sample as long as the number of timesteps
-        # done is smaller than the `train_batch_size`.
+        # For complete episodes mode, sample a single episode and
+        # leave coordination of sampling to `synchronous_parallel_sample`.
+        # TODO (simon, sven): The coordination will eventually move
+        # to `EnvRunnerGroup` in the future. So from the algorithm one
+        # would do `EnvRunnerGroup.sample()`.
         else:
-            total = 0
-            samples = []
-            while total < self.config.train_batch_size:
-                episodes = self._sample_episodes(
-                    num_episodes=self.num_envs,
-                    explore=explore,
-                    random_actions=random_actions,
-                )
-                total += sum(len(e) for e in episodes)
-                samples.extend(episodes)
+            samples = self._sample_episodes(
+                num_episodes=1,
+                explore=explore,
+                random_actions=random_actions,
+            )

         # Make the `on_sample_end` callback.
         self._callbacks.on_sample_end(
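For context, here is a minimal sketch of the coordination that this change moves out of `SingleAgentEnvRunner.sample()`: with `batch_mode="complete_episodes"`, each `sample()` call now returns a single finished episode, and the caller (conceptually `synchronous_parallel_sample`, eventually `EnvRunnerGroup.sample()` per the TODO) keeps requesting episodes until the train batch size is reached. The function and variable names below are illustrative assumptions, not actual RLlib APIs.

```python
# Hypothetical sketch of caller-side coordination for complete-episodes sampling.
# `env_runners` is any iterable of objects whose `sample()` returns a list of
# finished episodes; `train_batch_size` is the target number of env timesteps.
def sample_complete_episodes(env_runners, train_batch_size):
    """Collect complete episodes until at least `train_batch_size` timesteps."""
    episodes = []
    total_timesteps = 0
    while total_timesteps < train_batch_size:
        for runner in env_runners:
            # Each call returns a list with one finished episode
            # (mirroring `num_episodes=1` in the EnvRunner change above).
            new_episodes = runner.sample()
            episodes.extend(new_episodes)
            total_timesteps += sum(len(e) for e in new_episodes)
    return episodes
```

This also explains the removed test assertion: since the coordinating caller may invoke `sample()` more than once per training iteration, the EnvRunner can no longer guarantee exactly one `sample()` call.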