Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Hot fix for PPOTorchRLModule._compute_values with non-shared stateful encoder and batch slicing with non-empty infos. #44082

Conversation

simonsays1980
Copy link
Collaborator

@simonsays1980 simonsays1980 commented Mar 18, 2024

Why are these changes needed?

Running PPO with use_lstm=True and vf_share_layers=False results in an error in the PPOTorchRLModule._compute_values method as the specs checker expects a different spec for the state_in:

raise SpecCheckingError(
ray.rllib.core.models.specs.checker.SpecCheckingError: input spec validation failed on TorchLSTMEncoder.forward, The data dict does not match the model specs. Keys ('state_in', 'h') are in the spec dict but not on the data dict. Data keys are {('state_in', 'critic', 'c'), ('rewards',), ('action_logp',), ('state_in', 'actor', 'h'), ('state_in', 'critic', 'h'), ('terminateds',), ('obs',), ('truncateds',), ('vf_preds',), ('loss_mask',), ('seq_lens',), ('action_dist_inputs',), ('infos',), ('actions',), ('state_in', 'actor', 'c')}.

Exctracting the state_in for the critic solves this problem.

Another problem is solved related to non-empty infos in batch slicing (mainly occuring in MinibatchIterators). The reason is that slicing via tree.map_structure tries to slice also the entries of the infos which are usually singular values:

data = tree.map_structure(lambda value: value[start:stop], self)
IndexError: slice() cannot be applied to a 0-dim tensor.

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…ompute_values' when using a non-shared stateful encoder. In addition fixed an error that occurs while slicing batches with non-empty infos.

Signed-off-by: Simon Zehnder <[email protected]>
@simonsays1980 simonsays1980 self-assigned this Mar 18, 2024
@simonsays1980 simonsays1980 added rllib RLlib related issues rllib-newstack labels Mar 18, 2024
@sven1977 sven1977 changed the title Hot fix for PPOTorchRLModule._compute_values with non-shared stateful encoder and batch slicing with non-empty infos [RLlib] Hot fix for PPOTorchRLModule._compute_values with non-shared stateful encoder and batch slicing with non-empty infos. Mar 18, 2024
@sven1977 sven1977 marked this pull request as ready for review March 18, 2024 11:12
@@ -716,7 +716,9 @@ def _batch_slice(self, slice_: slice) -> "SampleBatch":

# Exclude INFOs from regular array slicing as the data under this column might
# be a list (not good for `tree.map_structure` call).
infos = self.get(SampleBatch.INFOS)
# Furthermore, slicing does not work when the data in the column is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean a SampleBatch with B=0, correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

B>0. But in this case the infos are a list of dicts. When they are empty, tree.map_structure(infos) works, but when they are filled, tree.map_structure will fail as it tries to apply the slicing on singular values slicing fails.

Copy link
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks for the fix @simonsays1980!

@@ -90,6 +90,10 @@ def _compute_values(self, batch, device=None):

# Separate vf-encoder.
if hasattr(self.encoder, "critic_encoder"):
if self.is_stateful():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! Ran into this issue yesterday as well (and continued testing then with a shared value function :) ).

@sven1977 sven1977 merged commit fd0c148 into ray-project:master Mar 18, 2024
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
rllib RLlib related issues rllib-newstack
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants