
[RLlib] Action masking example for new API stack. #46146

Merged
3 changes: 2 additions & 1 deletion rllib/env/single_agent_env_runner.py
@@ -447,7 +447,7 @@ def _sample_episodes(
        obs, infos = self.env.reset()
        for env_index in range(self.num_envs):
            episodes[env_index].add_env_reset(
-               observation=obs[env_index],
+               observation=unbatch(obs)[env_index],
Contributor
Strange: Shouldn't this already have caused a bug in the existing flatten obs example (which also uses a dict obs space)?

Collaborator Author

I think the flatten-obs example silently snuck around it: the unbatch(obs) call was only missing in _sample_episodes, not in _sample_timesteps (it cost me a lot of time to figure this out, though). Because the flatten-obs example does no evaluation, it never ran into it. :D

Contributor

Ahhh, yes, makes perfect sense. Thanks for catching this!


With ray==2.24.0, batch_mode="complete_episodes" still triggers the bug. The obs = unbatch(obs) call should sit outside the for-loop, as _sample_timesteps does:

        obs, infos = self.env.reset()
        obs = unbatch(obs)
        for env_index in range(self.num_envs):
            episodes[env_index].add_env_reset(
                observation=obs[env_index],

                infos=infos[env_index],
            )
            self._make_on_episode_callback("on_episode_start", env_index, episodes)
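For context on why unbatch is needed here, a minimal sketch (an illustrative stand-in, not RLlib's actual ray.rllib.utils.spaces.space_utils.unbatch implementation): a vectorized env with a dict observation space returns one dict whose values are batched arrays, so indexing the raw dict with obs[env_index] does not yield the per-env observation; unbatching transposes it into a list of per-env dicts first.

```python
import numpy as np

def unbatch_dict(obs: dict) -> list:
    """Illustrative stand-in for RLlib's unbatch(): turn a dict of
    batched arrays into a list of per-sub-env dicts."""
    num_envs = len(next(iter(obs.values())))
    return [{key: val[i] for key, val in obs.items()} for i in range(num_envs)]

# A vectorized env with 2 sub-envs and a dict obs space (e.g. the
# action-masking env) returns something shaped like this on reset():
batched = {
    "action_mask": np.array([[1, 0], [1, 1]]),
    "observations": np.array([[0.1, 0.2], [0.3, 0.4]]),
}

# Indexing the batched dict directly (as the old obs[env_index] code
# effectively did) would look up an integer key in the dict and fail.
per_env = unbatch_dict(batched)
assert per_env[0]["action_mask"].tolist() == [1, 0]
assert per_env[1]["observations"].tolist() == [0.3, 0.4]
```

This also shows why unbatching once before the loop (as in _sample_timesteps) is preferable to calling it per iteration inside the loop.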
@@ -724,6 +724,7 @@ def make_env(self) -> None:
asynchronous=self.config.remote_worker_envs,
)
)

self.num_envs: int = self.env.num_envs
assert self.num_envs == self.config.num_envs_per_env_runner

149 changes: 0 additions & 149 deletions rllib/examples/action_masking.py

This file was deleted.
