
[RLlib] Use PyTorch vectorized max() and sum() in SampleBatch.__init__ when possible #28388

Merged
4 commits merged into ray-project:master on Jan 12, 2023

Conversation

@cassidylaidlaw (Contributor) commented on Sep 8, 2022

Why are these changes needed?

Currently, there are two lines in SampleBatch.__init__ where the Python builtins sum and max are used on seq_lens:

  1. self.max_seq_len = max(seq_lens_)
  2. self.count = sum(self[SampleBatch.SEQ_LENS])

However, if seq_lens is a PyTorch tensor on the GPU, this is incredibly slow. I believe this is because each element of seq_lens has to be fetched independently from the GPU when iterating over the tensor for sum and max. Thus, I have changed the two lines to use seq_lens_.max().item() and self[SampleBatch.SEQ_LENS].sum().item() for PyTorch tensors, as sketched below.
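
For illustration, here is a minimal sketch of the idea (the helper functions and names below are illustrative, not the actual SampleBatch.__init__ code): only genuine torch tensors take the vectorized path, so the reduction happens on the device and only a single scalar is transferred back to the host.

import torch

def _max_seq_len(seq_lens):
    # Hypothetical helper, not RLlib's actual code: one vectorized reduction
    # on the device, then a single scalar transfer via .item().
    if torch.is_tensor(seq_lens):
        return seq_lens.max().item()
    # Lists / numpy arrays keep the original Python builtin path.
    return max(seq_lens)

def _count(seq_lens):
    if torch.is_tensor(seq_lens):
        return seq_lens.sum().item()
    return sum(seq_lens)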

This significantly speeds up training with models that require state. For instance, consider the following training run:

rllib train --run PPO --env CartPole-v0 \
--config '{"train_batch_size": 2000, "sgd_minibatch_size": 1000, "num_sgd_iter": 100, "framework": "torch", "num_workers": 10, "num_gpus": 1, "model": {"use_lstm": true}, "rollout_fragment_length": 1}' \
--stop '{"training_iteration": 1}'

Before this PR, it takes 22.7 s to run, but after only 13.5 s. If we look at result["timers"]["learn_time_ms"], we can also see that it is 8042 before the PR but 3162 after. Thus in some cases we're getting more than a 2x speedup for SGD!
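
To see where the time goes, here is a standalone micro-benchmark (not part of the PR, and it assumes a CUDA device is available) that compares the builtin sum, which fetches every element from the GPU one at a time, against a single vectorized reduction:

import time
import torch

seq_lens = torch.randint(1, 20, (4096,), device="cuda")

torch.cuda.synchronize()
t0 = time.perf_counter()
slow = sum(seq_lens)          # builtin: iterates the tensor, one GPU->host fetch per element
torch.cuda.synchronize()
t1 = time.perf_counter()
fast = seq_lens.sum().item()  # vectorized: one reduction on the device, one scalar transfer
torch.cuda.synchronize()
t2 = time.perf_counter()

print(f"builtin sum: {t1 - t0:.4f} s, tensor.sum().item(): {t2 - t1:.4f} s")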

Related issue number

I haven't opened an issue.

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

stale bot commented on Oct 29, 2022

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.

  • If you'd like to keep this open, just leave any comment, and the stale label will be removed.

stale bot added the stale label on Oct 29, 2022.
@cassidylaidlaw (Contributor, Author) commented

It looks like the only tests that are failing in CI are ones that are flaky. Can somebody review this?

stale bot removed the stale label on Nov 18, 2022.
@cassidylaidlaw (Contributor, Author) commented

Updated the PR and tests seem to be passing again. Can anyone look over and/or merge? @sven1977 @gjoliver @avnishn @ArturNiederfahrenhorst @smorad @maxpumperla @kouroshHakha @krfricke

@kouroshHakha (Contributor) left a comment

Sounds reasonable to me.

@cassidylaidlaw (Contributor, Author) commented

Another bump—will this be merged soon?

stale bot commented on Dec 31, 2022

This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.

  • If you'd like to keep this open, just leave any comment, and the stale label will be removed.

stale bot added the stale label on Dec 31, 2022.
@cassidylaidlaw (Contributor, Author) commented

Is there anything I can do to get this merged? @kouroshHakha

@kouroshHakha (Contributor) commented on Jan 8, 2023

cc @gjoliver @sven1977

stale bot removed the stale label on Jan 8, 2023.
@gjoliver (Member) commented

Cool, looks like an awesome change for torch.
Does this make things slower for TF? Just want to have an idea.

gjoliver merged commit 2bc8837 into ray-project:master on Jan 12, 2023.
AmeerHajAli pushed a commit that referenced this pull request on Jan 12, 2023.
@cassidylaidlaw (Contributor, Author) commented

It shouldn't make things any slower in TF. The PyTorch vectorized max/sum are only used if the tensors are from torch; otherwise the code runs the same as before the PR, as illustrated below.
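
As a quick illustrative check (a hypothetical _count helper mirroring the sketch above, not RLlib code), anything that is not a torch tensor falls through to the builtin, so the TF and numpy code paths are untouched:

import numpy as np
import torch

def _count(seq_lens):
    # Only torch tensors hit the vectorized branch; lists, numpy arrays,
    # and TF tensors are summed exactly as before the PR.
    if torch.is_tensor(seq_lens):
        return seq_lens.sum().item()
    return sum(seq_lens)

print(_count([2, 3, 4]))                 # builtin path -> 9
print(_count(np.array([2, 3, 4])))       # builtin path -> 9
print(_count(torch.tensor([2, 3, 4])))   # vectorized torch path -> 9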
