-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Action masking example for new API stack. #46146
[RLlib] Action masking example for new API stack. #46146
Conversation
…add_env_reset' calls in 'SingleAgentEnvRunner' to deal with 'Dict' observation spaces. Signed-off-by: simonsays1980 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super cool PR! One more example script TODO down.
We need to add this example to BUILD as well.
@@ -447,7 +447,7 @@ def _sample_episodes( | |||
obs, infos = self.env.reset() | |||
for env_index in range(self.num_envs): | |||
episodes[env_index].add_env_reset( | |||
observation=obs[env_index], | |||
observation=unbatch(obs)[env_index], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Strange: Shouldn't this already have caused a bug in the existing flatten obs example (which also uses a dict obs space)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, the flatten obs example silently snuck around it: The unbatch(obs)
was only forgotten in _sample_episodes
and not in _sample_timesteps
(has cost me a lot of time to figure it out thoug). Because the flatten obs does no evaluation - it never ran into it :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahhh, yes, makes perfect sense. Thanks for catching this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With ray==2.24.0
, batch_mode="complete_episodes"
will cause the bug. And obs = unbatch(obs)
should outter the for-loop like what _sample_timesteps
do:
obs, infos = self.env.reset()
obs = unbatch(obs)
for env_index in range(self.num_envs):
episodes[env_index].add_env_reset(
observation=obs[env_index],
…ngs and comments. Added action-masking and autoregressive actions examples to the BUILD. Signed-off-by: simonsays1980 <[email protected]>
…ll be deprecated in very near future. Signed-off-by: simonsays1980 <[email protected]>
Signed-off-by: simonsays1980 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the additional fixes and answers @simonsays1980 . Very nice PR!
…on masking module naming. Signed-off-by: simonsays1980 <[email protected]>
…nter. Signed-off-by: simonsays1980 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
approval for CI files.
…ong name in the BUILD file. Signed-off-by: simonsays1980 <[email protected]>
Why are these changes needed?
This PR adds an example for using action masking in the new API stack to the repository. In addition it makes a small change to the
SingleAgentEnvRunner
to deal withDict
observation spaces.Related issue number
Closes #44780 #44452
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.