[RLlib] `before_sub_environment_reset()` callback enhancements (add `next_episode` arg). #28600
Signed-off-by: sven1977 <[email protected]>
…re_sub_env_reset_callback Signed-off-by: sven1977 <[email protected]> # Conflicts: # rllib/algorithms/callbacks.py # rllib/algorithms/tests/test_callbacks.py # rllib/evaluation/env_runner_v2.py # rllib/evaluation/sampler.py
rllib/evaluation/env_runner_v2.py
Outdated
@@ -791,6 +792,8 @@ def _handle_done_episode(
                    env_id
                ],
                env_index=env_id,
                # Create new episode under this env_id.
                next_episode=self._active_episodes[env_id],
OK, I chatted with Wes. If this is what they need, we don't really need to introduce a new callback, since accessing `self._active_episodes[env_id]` here would trigger the `on_episode_start()` call.
Essentially, what we are doing here is moving the `on_episode_start()` call to happen before `reset()`.
So all we have to do is access `self._active_episodes[env_id]` here, with a comment that says "Assign the policy mapping for the next episode before env.reset() call".
Does this make sense?
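A minimal sketch of why a bare lookup can have side effects, assuming a dict that builds missing entries on access. `EpisodeDict` and `on_create` are hypothetical names for illustration, not RLlib's actual classes or hooks:

```python
class EpisodeDict(dict):
    """Hypothetical stand-in for a defaultdict-style episode registry."""

    def __init__(self, on_create):
        super().__init__()
        self._on_create = on_create  # hook fired whenever a new entry is built

    def __missing__(self, env_id):
        # dict.__getitem__ calls this only when env_id is not present yet.
        episode = {"env_id": env_id, "length": 0}
        self[env_id] = episode
        self._on_create(episode)
        return episode


events = []
active_episodes = EpisodeDict(on_create=lambda ep: events.append(ep["env_id"]))

active_episodes[0]  # first access: creates the episode and fires the hook
active_episodes[0]  # second access: plain lookup, the hook does not fire
```

In this toy version, the hook runs exactly once per key, at the first lookup, which is the behavior the comment above relies on.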
As per our discussion with Wes yesterday:
- We add a new callback, `on_episode_created`, which is triggered right after the Episode(V2)? instance has been instantiated, but before(!) the sub-environment is reset.
- Then we reset the sub-environment.
- Then we trigger `on_episode_start` (like we did before). This way the `on_episode_start` behavior remains unaltered/no API changes.
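The three steps above can be sketched as follows. This is an illustration only: `RecordingCallbacks`, `ToyEnv`, and `start_new_episode` are hypothetical stand-ins, and the callback arguments are reduced to `env_index` and `episode` rather than the full RLlib signature:

```python
class RecordingCallbacks:
    """Hypothetical callbacks object that only records the order of events."""

    def __init__(self):
        self.events = []

    def on_episode_created(self, *, env_index, episode):
        self.events.append(("on_episode_created", env_index))

    def on_episode_start(self, *, env_index, episode):
        self.events.append(("on_episode_start", env_index))


class ToyEnv:
    """Toy sub-environment that logs when it is reset."""

    def __init__(self, log):
        self._log = log

    def reset(self):
        self._log.append(("reset", 0))
        return 0  # initial observation


def start_new_episode(env, callbacks, env_index=0):
    # 1. Instantiate the episode and announce it (env not yet reset).
    episode = {"obs": None}
    callbacks.on_episode_created(env_index=env_index, episode=episode)
    # 2. Reset the sub-environment.
    episode["obs"] = env.reset()
    # 3. Fire the unchanged on_episode_start callback.
    callbacks.on_episode_start(env_index=env_index, episode=episode)
    return episode


callbacks = RecordingCallbacks()
env = ToyEnv(callbacks.events)
start_new_episode(env, callbacks)
order = [name for name, _ in callbacks.events]
```

Here `order` ends up as `["on_episode_created", "reset", "on_episode_start"]`, so code in `on_episode_created` can influence the upcoming reset while `on_episode_start` still sees a reset environment.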
@@ -108,26 +108,36 @@ def on_sub_environment_created(
        pass

    @OverrideToImplementCustomLogic
-   def before_sub_environment_reset(
+   def on_episode_created(
Renamed the callback as discussed.
        self,
        *,
        worker: "RolloutWorker",
        sub_environment: EnvType,
        base_env: BaseEnv,
Same signature as `on_episode_start`.
@@ -131,15 +141,17 @@ def test_before_sub_environment_reset(self):
            .callbacks(BeforeSubEnvironmentResetCallback)
        )

        for _ in framework_iterator(config, frameworks=("tf", "torch")):
        # Test with and without Connectors.
Testing with and without connectors is more important than testing different frameworks (env runners are framework-agnostic).
@@ -394,15 +396,12 @@ def run(self) -> Iterator[SampleBatchType]:
            and other fields as dictated by `policy`.
        """
        # Before the very first poll (this will reset all vector sub-environments):
        # Call custom `before_sub_environment_reset` callbacks for all sub-environments.
        # Create all upcoming episodes and call `on_episode_created` callbacks for
Create initial episodes and do callbacks.
@@ -854,10 +850,35 @@ def _handle_done_episode(
        # Step after adding initial obs. This will give us 0 env and agent step.
        new_episode.step()

    def create_episode(self, env_id: EnvID) -> EpisodeV2:
Helper.
@@ -856,7 +831,11 @@ def _process_observations(
            # This will be filled with dummy observations below.
            all_agents_obs = {}

            if not is_new_episode:
            # If this episode is brand-new, call the episode start callback(s).
Needed to add this flag here to the old `Episode` class. Otherwise, we would never start increasing the `length` property. This is not a problem for `EpisodeV2`.
rllib/evaluation/env_runner_v2.py
Outdated
        """
        # Create a new episode under the same `env_id` and call the
        # `on_episode_created` callbacks.
        new_episode = self._active_episodes[env_id]
Ideally, I would like to get rid of `_active_episodes` being a `defaultdict`. I think it adds a lot of confusion here, making things happen under the hood without tight control by the env runner itself.
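One way to picture the explicit alternative is a plain dict plus a creation helper. The helper name `create_episode` appears in the diff above, but the body below, the `ToyEnvRunner` class, and the `created` list are assumptions for illustration only:

```python
class ToyEnvRunner:
    """Hypothetical runner that creates episodes explicitly."""

    def __init__(self):
        self._active_episodes = {}  # plain dict: lookups never create entries
        self.created = []           # records every explicit creation

    def create_episode(self, env_id):
        # Creation is the only place where a "created" hook could fire.
        if env_id in self._active_episodes:
            raise ValueError(f"episode for env_id={env_id} already exists")
        episode = {"env_id": env_id}
        self._active_episodes[env_id] = episode
        self.created.append(env_id)
        return episode


runner = ToyEnvRunner()
runner.create_episode(0)

lookup_failed = False
try:
    runner._active_episodes[1]  # a stray lookup now fails loudly
except KeyError:
    lookup_failed = True
```

With this layout, a lookup for an unknown `env_id` raises `KeyError` instead of silently creating an episode and firing callbacks.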
Cool, still have some minor questions/comments. Looking solid.
rllib/algorithms/callbacks.py
Outdated
        This method gets called before every `try_reset()` is called by RLlib
        on a sub-environment (usually a gym.Env). This includes the very first (initial)
        reset performed on each sub-environment.
        episode
Is this a typo? Should we remove it?
Can I suggest we also update the docstring for `on_episode_start()` to point out the difference between it and `on_episode_created()`? Basically, `on_episode_start()` gets called after `base_env.try_reset()` is done.
Users may be reading about `on_episode_start()` without noticing the details here.
Yeah, sorry, a leftover. :) Removed.
Cleaned up both docstrings and added the exact sequence of events to both.
@@ -530,6 +516,10 @@ def _process_observations(
                continue

            episode: EpisodeV2 = self._active_episodes[env_id]
            # If this episode is brand-new, call the episode start callback(s).
            # Note: EpisodeV2s are initialized with length=-1 (before the reset).
            if episode.length == -1:
Can I suggest we don't rely on the internals of `EpisodeV2` directly?
It would be better to create an API on `EpisodeV2` similar to the one on `Episode`:

    @property
    def started(self) -> bool:
        return bool(self._has_init_obs)

Then we can do:

    if not episode.started:
        self._call_on_episode_start(episode, env_id)

It's probably safer to rely on `self._has_init_obs` than on the initial value of `length`.
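A self-contained sketch of the suggested property, using a toy class rather than RLlib's `EpisodeV2`. It assumes `_has_init_obs` is a per-agent dict that stays empty until the first observation is added; `ToyEpisode` and `add_init_obs` are illustrative names:

```python
class ToyEpisode:
    """Hypothetical stand-in for an episode class; not RLlib's EpisodeV2."""

    def __init__(self):
        self.length = -1         # sentinel value before the first reset
        self._has_init_obs = {}  # assumed: agent_id -> True once obs arrives

    @property
    def started(self) -> bool:
        # An empty dict is falsy, so this is False until an obs is recorded.
        return bool(self._has_init_obs)

    def add_init_obs(self, agent_id, obs):
        self._has_init_obs[agent_id] = True
        self.length = 0


episode = ToyEpisode()
before = episode.started
episode.add_init_obs(agent_id="agent_0", obs=0)
after = episode.started
```

Callers then check `episode.started` and never need to know which sentinel value `length` starts from.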
Another quick question: why do we do this here?
Why don't we call `self._call_on_episode_start(episode, env_id)` right after the `reset()` op is done?
Although we may need to do this in a couple of places, I feel it's mentally easier if they happen one right after the other.
I think we have to do this here because Ray remote envs behave differently: they return an `ASYNC_RESET_RETURN` upon reset and only publish those reset results via the next `poll` call, unlike "normal", non-remote envs.
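A toy illustration of that constraint, assuming a reset that returns only a sentinel and delivers the observation on the following `poll()`. `ToyRemoteEnv` and the sentinel's value are placeholders, not RLlib's implementation:

```python
ASYNC_RESET_RETURN = "async_reset_return"  # placeholder sentinel value


class ToyRemoteEnv:
    """Toy env whose reset result only shows up on the next poll()."""

    def __init__(self):
        self._pending_obs = None

    def try_reset(self):
        self._pending_obs = 0  # the reset result is not handed out yet
        return ASYNC_RESET_RETURN

    def poll(self):
        obs, self._pending_obs = self._pending_obs, None
        return obs


events = []
env = ToyRemoteEnv()
started = False

result = env.try_reset()
if result != ASYNC_RESET_RETURN:
    # Only non-remote envs would have an observation at this point.
    events.append("on_episode_start")
    started = True

obs = env.poll()
if obs is not None and not started:
    # For the remote env, the start callback can only fire here,
    # while processing the polled observations.
    events.append("on_episode_start")
    started = True
```

Placing the check in the observation-processing step covers both cases with one code path, at the cost of separating it from the reset call.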
We do `next_episode_length = episode.length + 1` right after this, though :)
That also accesses the episode's internal properties.
Either way: fixed it and added the suggested property.
Huh, remote envs ... 😢
Thanks for the fix, though. Feels a bit better that we don't rely on that -1.
Thanks for powering through all the changes :)
A lot of tests are failing though, and they look related.