
[RLlib] before_sub_environment_reset() callback enhancements (add next_episode arg). #28600

Merged (22 commits) on Sep 23, 2022

Conversation

@sven1977 (Contributor) commented Sep 19, 2022

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…re_sub_env_reset_callback


# Conflicts:
#	rllib/algorithms/callbacks.py
#	rllib/algorithms/tests/test_callbacks.py
#	rllib/evaluation/env_runner_v2.py
#	rllib/evaluation/sampler.py
@@ -791,6 +792,8 @@ def _handle_done_episode(
env_id
],
env_index=env_id,
# Create new episode under this env_id.
next_episode=self._active_episodes[env_id],
Member (gjoliver):

OK, I chatted with Wes. If this is what they need, we don't really need to introduce a new callback, since doing self._active_episodes[env_id] here would already trigger the on_episode_start() call.
Essentially, what we are doing here is moving the on_episode_start() call to happen before reset().
So all we have to do is make the self._active_episodes[env_id] call here, with a comment saying "Assign the policy mapping for the next episode before the env.reset() call."
Does this make sense?

Contributor Author (sven1977):

As per our discussion with Wes yesterday:

  • We add a new callback, on_episode_created, which is triggered right after the Episode (or EpisodeV2) instance has been instantiated, but before(!) the sub-environment is reset.
  • Then we reset the sub-environment.
  • Then we trigger on_episode_start (like we did before). This way, the on_episode_start behavior remains unaltered (no API changes). A toy illustration of the resulting sequence follows below.
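The following is a toy, self-contained illustration of that ordering (not RLlib code; the env, episode, and callback objects here are stand-ins):

class Callbacks:
    def on_episode_created(self, episode):
        # New hook: the episode object exists, but the env is NOT reset yet.
        print("on_episode_created, length =", episode["length"])

    def on_episode_start(self, episode):
        # Unchanged hook: fires only after the env has been reset.
        print("on_episode_start, initial obs is available")


class DummyEnv:
    def reset(self):
        return 0  # initial observation


def start_episode(env, callbacks):
    episode = {"length": -1}               # freshly created episode
    callbacks.on_episode_created(episode)  # 1) new callback, before reset
    obs = env.reset()                      # 2) reset the sub-environment
    callbacks.on_episode_start(episode)    # 3) unchanged, after reset
    return episode, obs


start_episode(DummyEnv(), Callbacks())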

@@ -108,26 +108,36 @@ def on_sub_environment_created(
pass

@OverrideToImplementCustomLogic
def before_sub_environment_reset(
def on_episode_created(
Contributor Author (sven1977):

Renamed the callback as discussed.

self,
*,
worker: "RolloutWorker",
sub_environment: EnvType,
base_env: BaseEnv,
Contributor Author (sven1977):

Same signature as on_episode_start
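
For reference, a rough sketch of a user-side override, assuming DefaultCallbacks from rllib/algorithms/callbacks.py as the base class and accepting all keyword arguments via **kwargs (the exact parameter list follows the diff above and may differ in the merged version):

from ray.rllib.algorithms.callbacks import DefaultCallbacks


class MyCallbacks(DefaultCallbacks):
    def on_episode_created(self, **kwargs):
        # Fires right after the Episode(V2) object has been created and
        # BEFORE the sub-environment is reset.
        episode = kwargs.get("episode")
        print("episode created, env not reset yet:", episode)

    def on_episode_start(self, **kwargs):
        # Unchanged semantics: fires only AFTER base_env.try_reset() is done.
        episode = kwargs.get("episode")
        print("episode started, env has been reset:", episode)

This would be wired up the same way as in the test diff that follows, e.g. via config.callbacks(MyCallbacks).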

@@ -131,15 +141,17 @@ def test_before_sub_environment_reset(self):
.callbacks(BeforeSubEnvironmentResetCallback)
)

for _ in framework_iterator(config, frameworks=("tf", "torch")):
# Test with and without Connectors.
Contributor Author (sven1977):

Testing with and without connectors is more important here than testing across frameworks (the env-runners are framework-agnostic).
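
A rough sketch of the resulting test loop, assuming the enable_connectors flag on AlgorithmConfig.rollouts() available at the time of this PR (flag name assumed, rest of the test elided):

# Iterate over both connector modes instead of over frameworks.
for use_connectors in (True, False):
    test_config = config.rollouts(enable_connectors=use_connectors)
    algo = test_config.build()
    algo.train()
    algo.stop()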

@@ -394,15 +396,12 @@ def run(self) -> Iterator[SampleBatchType]:
and other fields as dictated by `policy`.
"""
# Before the very first poll (this will reset all vector sub-environments):
# Call custom `before_sub_environment_reset` callbacks for all sub-environments.
# Create all upcoming episodes and call `on_episode_created` callbacks for
Contributor Author (sven1977):

Create initial episodes and do callbacks.

@@ -854,10 +850,35 @@ def _handle_done_episode(
# Step after adding initial obs. This will give us 0 env and agent step.
new_episode.step()

def create_episode(self, env_id: EnvID) -> EpisodeV2:
Contributor Author (sven1977):

Helper for creating a new episode and firing its on_episode_created callback(s).

@@ -856,7 +831,11 @@ def _process_observations(
# This will be filled with dummy observations below.
all_agents_obs = {}

if not is_new_episode:
# If this episode is brand-new, call the episode start callback(s).
Contributor Author (sven1977):

Needed to add this flag here to the old Episode class. Otherwise, we would never start increasing the length property. This is not a problem for EpisodeV2.

"""
# Create a new episode under the same `env_id` and call the
# `on_episode_created` callbacks.
new_episode = self._active_episodes[env_id]
Contributor Author (sven1977):

Ideally, I would like to get rid of active_episodes being a defaultdict. I think it adds a lot of confusion here, making things happen under the hood without tight control by the env_runner itself.
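
To illustrate the concern, with a defaultdict merely reading a missing env_id silently creates a brand-new entry, whereas a plain dict forces the env_runner to create episodes explicitly (standard-library behavior, not RLlib code):

from collections import defaultdict

# defaultdict: looking up a missing key creates the episode as a side effect.
active_episodes = defaultdict(lambda: {"length": -1})
episode = active_episodes["env_0"]  # new episode appears "under the hood"

# plain dict: creation has to be explicit and is visible to the env_runner.
active_episodes = {}
if "env_0" not in active_episodes:
    active_episodes["env_0"] = {"length": -1}  # e.g. via create_episode()
episode = active_episodes["env_0"]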

@gjoliver (Member) left a comment:

Cool, I still have some minor questions/comments.
Looking solid.

This method gets called before every `try_reset()` is called by RLlib
on a sub-environment (usually a gym.Env). This includes the very first (initial)
reset performed on each sub-environment.
episode
Member (gjoliver):

Is this a typo? Should it be removed?

Member (gjoliver):

Can I suggest we also update the docstring for on_episode_start() to point out the difference between it and on_episode_created()? Basically, on_episode_start() gets called only after base_env.try_reset() is done.
Users may be reading about on_episode_start() without noticing the details discussed here.

Contributor Author (sven1977):

Yeah, sorry, that was a leftover. :) Removed.

Cleaned up both docstrings and added the exact sequence of events to both.

@@ -530,6 +516,10 @@ def _process_observations(
continue

episode: EpisodeV2 = self._active_episodes[env_id]
# If this episode is brand-new, call the episode start callback(s).
# Note: EpisodeV2s are initialized with length=-1 (before the reset).
if episode.length == -1:
Member (gjoliver):

Can I suggest we don't rely on the internals of EpisodeV2 directly?
It would be better if we created an API on EpisodeV2 similar to Episode:

@property
def started(self) -> bool:
    return bool(self._has_init_obs)

Then we can do:

if not episode.started:
    self._call_on_episode_start(episode, env_id)

Isn't it likely safer if we rely on self._has_init_obs rather than on the initial value of length?

Member (gjoliver):

Another quick question: why do we do this here?
Why don't we call self._call_on_episode_start(episode, env_id) right after the reset() op is done? Although we may need to do this in a couple of places, I feel it's mentally easier if the two happen one right after the other.

Contributor Author (sven1977):

I think we have to do this here due to the different behavior of ray-remote envs, which return an ASYNC_RESET_RETURN upon reset and only publish those reset results via the next poll call, unlike "normal", non-remote envs.

Contributor Author (sven1977):

We do next_episode_length = episode.length + 1 right after this, though :) That also accesses the episode's internal properties.

Either way: fixed it and added the suggested property.

Member (gjoliver):

Huh, remote envs ... 😢
Thanks for the fix though. It feels a bit better that we don't rely on that -1.

@gjoliver (Member) left a comment:

Thanks for powering through all the changes :)

@gjoliver (Member):

A lot of tests are failing though, and they look related.
