[RLlib] Use PyTorch vectorized max() and sum() in SampleBatch.__init__ when possible #28388
Conversation
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
It looks like the only tests that are failing in CI are ones that are flaky. Can somebody review this?
Force-pushed from 2d473bc to f804275.
Updated the PR and tests seem to be passing again. Can anyone look over and/or merge? @sven1977 @gjoliver @avnishn @ArturNiederfahrenhorst @smorad @maxpumperla @kouroshHakha @krfricke
Sounds reasonable to me.
Another bump: will this be merged soon?
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 14 days if no further activity occurs. Thank you for your contributions.
Is there anything I can do to get this merged? @kouroshHakha
Cool. Looks like an awesome change for torch.
It shouldn't make things any slower in TF. The PyTorch vectorized max/sum are only used if the tensors are from torch; otherwise the code runs the same as before the PR.
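The dispatch described in this comment (vectorized torch reductions when the input is a torch tensor, the old builtin path otherwise) can be sketched roughly as follows. This is a minimal illustration, not RLlib's actual code, and `max_and_sum` is a hypothetical helper name:

```python
import torch

def max_and_sum(seq_lens):
    """Hypothetical sketch of the torch/non-torch dispatch discussed above."""
    if torch.is_tensor(seq_lens):
        # One vectorized kernel per reduction plus a single
        # device-to-host copy via .item(), instead of fetching
        # each element individually while Python iterates.
        return seq_lens.max().item(), seq_lens.sum().item()
    # Lists, numpy arrays, etc. take the pre-PR builtin path unchanged.
    return max(seq_lens), sum(seq_lens)

print(max_and_sum([2, 5, 3]))                # (5, 10)
print(max_and_sum(torch.tensor([2, 5, 3])))  # (5, 10)
```

Both branches return plain Python numbers, so callers don't need to care which path was taken.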
Why are these changes needed?

Currently, there are two lines in `SampleBatch.__init__` where the Python builtins `sum` and `max` are used on `seq_lens`:

```python
self.max_seq_len = max(seq_lens_)
self.count = sum(self[SampleBatch.SEQ_LENS])
```

However, if the `seq_lens` are a PyTorch tensor on the GPU, this is incredibly slow. I believe this is because each element of `seq_lens` has to be fetched independently from the GPU when iterating over the tensor for `sum` and `max`. Thus, I have changed the two lines to use `seq_lens_.max().item()` and `self[SampleBatch.SEQ_LENS].sum().item()` for PyTorch tensors.

This significantly speeds up training with models that require state. For instance, consider the following training run:
Before this PR, it takes 22.7 s to run, but afterwards only 13.5 s. If we look at `result["timers"]["learn_time_ms"]`, we can also see that it is 8042 before the PR but 3162 after. Thus, in some cases we're getting more than a 2x speedup for SGD!

Related issue number
I haven't opened an issue.
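The per-element-fetch effect described above can be reproduced in isolation with a micro-benchmark along these lines (hypothetical, not part of the PR). On a CUDA tensor the gap is far larger than on CPU, since each builtin iteration step triggers a separate device-to-host fetch, but both paths compute identical values either way:

```python
import time
import torch

# Stand-in for a SampleBatch's seq_lens column.
seq_lens = torch.randint(1, 20, (50_000,))

# Old path: Python builtins iterate over the tensor element by element.
t0 = time.perf_counter()
builtin_result = (max(seq_lens).item(), sum(seq_lens).item())
builtin_time = time.perf_counter() - t0

# New path: single vectorized reduction per value.
t0 = time.perf_counter()
vectorized_result = (seq_lens.max().item(), seq_lens.sum().item())
vectorized_time = time.perf_counter() - t0

# Both paths compute the same values; only the speed differs.
assert builtin_result == vectorized_result
print(f"builtin: {builtin_time:.4f}s, vectorized: {vectorized_time:.4f}s")
```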
Checks

- I've signed off every commit (using `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.