
Commit

[RLlib] Fix calling of callback on_episode_created to conform to docstring (after reset). (ray-project#45651)

Signed-off-by: Richard Liu <[email protected]>
simonsays1980 authored and richardsliu committed Jun 12, 2024
1 parent 4a14bfe commit 9eed717
Showing 3 changed files with 18 additions and 14 deletions.
11 changes: 7 additions & 4 deletions rllib/algorithms/callbacks.py
@@ -245,19 +245,22 @@ def on_episode_created(
         """Callback run when a new episode is created (but has not started yet!).

         This method gets called after a new Episode(V2) (old stack) or
-        SingleAgentEpisode/MultiAgentEpisode instance has been created.
+        MultiAgentEpisode instance has been created.
         This happens before the respective sub-environment's (usually a gym.Env)
         `reset()` is called by RLlib.

-        1) Episode(V2)/Single-/MultiAgentEpisode created: This callback is called.
+        Note, at the moment this callback does not get called in the new API stack
+        and single-agent mode.
+
+        1) Episode(V2)/MultiAgentEpisode created: This callback is called.
         2) Respective sub-environment (gym.Env) is `reset()`.
         3) Callback `on_episode_start` is called.
         4) Stepping through sub-environment/episode commences.

         Args:
             episode: The newly created episode. On the new API stack, this will be a
-                SingleAgentEpisode or MultiAgentEpisode object. On the old API stack,
-                this will be a Episode or EpisodeV2 object.
+                MultiAgentEpisode object. On the old API stack, this will be a
+                Episode or EpisodeV2 object.
                 This is the episode that is about to be started with an upcoming
                 `env.reset()`. Only after this reset call, the `on_episode_start`
                 callback will be called.
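For readers unfamiliar with the hook, here is a minimal sketch (not part of this commit) of user code built on the ordering documented above: a DefaultCallbacks subclass overriding both callbacks. The `id_` attribute assumed here is the new-stack episode id, and `**kwargs` absorbs the remaining callback arguments (env_runner/worker, env_index, etc.):

# Hypothetical illustration of the documented callback order.
from ray.rllib.algorithms.callbacks import DefaultCallbacks


class EpisodeLifecycleLogger(DefaultCallbacks):
    def on_episode_created(self, *, episode, **kwargs) -> None:
        # Fires after the episode object is built but before `env.reset()`,
        # so the episode holds no observations yet.
        print(f"created: {episode.id_}")

    def on_episode_start(self, *, episode, **kwargs) -> None:
        # Fires after `env.reset()`; the initial observation is now stored.
        print(f"started: {episode.id_}")

Such a class would be registered with the algorithm via `AlgorithmConfig.callbacks(EpisodeLifecycleLogger)`.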
7 changes: 3 additions & 4 deletions rllib/env/multi_agent_env_runner.py
@@ -414,17 +414,16 @@ def _sample_episodes(

         done_episodes_to_return: List[MultiAgentEpisode] = []

-        # Reset the environment.
-        # TODO (simon): Check, if we need here the seed from the config.
-        obs, infos = self.env.reset()
-
         # Create a new multi-agent episode.
         _episode = self._new_episode()
         self._make_on_episode_callback("on_episode_created", _episode)
         _shared_data = {
             "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
         }

+        # Reset the environment.
+        # TODO (simon): Check, if we need here the seed from the config.
+        obs, infos = self.env.reset()
         # Set initial obs and infos in the episodes.
         _episode.add_env_reset(observations=obs, infos=infos)
         self._make_on_episode_callback("on_episode_start", _episode)
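The practical effect of this reordering, as a standalone runnable sketch (plain gymnasium with print stand-ins for the callbacks; hypothetical, not RLlib internals):

import gymnasium as gym

env = gym.make("CartPole-v1")

episode = []                  # stand-in for a MultiAgentEpisode
print("on_episode_created")   # 1) fires on an empty, pre-reset episode
obs, info = env.reset()       # 2) only now is the environment reset
episode.append(obs)           # 3) initial observation enters the episode
print("on_episode_start")     # 4) stepping may begin
terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    episode.append(obs)

Creating the episode and firing `on_episode_created` before `env.reset()` is what the docstring promises; before this commit, the reset came first.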
14 changes: 8 additions & 6 deletions rllib/env/single_agent_env_runner.py
@@ -224,11 +224,12 @@ def _sample_timesteps(

         # Have to reset the env (on all vector sub_envs).
         if force_reset or self._needs_initial_reset:
-            # Create n new episodes and make the `on_episode_created` callbacks.
+            # Create n new episodes.
+            # TODO (sven): Add callback `on_episode_created` as soon as
+            # `gymnasium-v1.0.0a2` PR is coming.
             self._episodes = []
             for env_index in range(self.num_envs):
                 self._episodes.append(self._new_episode())
-                self._make_on_episode_callback("on_episode_created", env_index)
             self._shared_data = {}

             # Erase all cached ongoing episodes (these will never be completed and
@@ -437,15 +438,16 @@ def _sample_episodes(

         done_episodes_to_return: List[SingleAgentEpisode] = []

-        # Reset the environment.
-        # TODO (simon): Check, if we need here the seed from the config.
-        obs, infos = self.env.reset()
         episodes = []
         for env_index in range(self.num_envs):
             episodes.append(self._new_episode())
-            self._make_on_episode_callback("on_episode_created", env_index, episodes)
+        # TODO (sven): Add callback `on_episode_created` as soon as
+        # `gymnasium-v1.0.0a2` PR is coming.
         _shared_data = {}

+        # Reset the environment.
+        # TODO (simon): Check, if we need here the seed from the config.
+        obs, infos = self.env.reset()
         for env_index in range(self.num_envs):
             episodes[env_index].add_env_reset(
                 observation=obs[env_index],
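The per-`env_index` bookkeeping in this hunk, sketched against gymnasium's vector API (illustrative names, not the actual SingleAgentEnvRunner):

import gymnasium as gym

num_envs = 2
env = gym.vector.SyncVectorEnv(
    [lambda: gym.make("CartPole-v1") for _ in range(num_envs)]
)

# One (empty) episode record per vectorized sub-environment.
episodes = [[] for _ in range(num_envs)]

# A single batched reset serves all sub-environments at once.
obs, infos = env.reset()
for env_index in range(num_envs):
    # Route each sub-environment's initial observation to its own episode,
    # mirroring the `add_env_reset(observation=obs[env_index], ...)` calls above.
    episodes[env_index].append(obs[env_index])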
