Enable autograd graph to propagate after multi-device syncing (for loss functions in ddp) #2754

Open

cw-tan wants to merge 2 commits into master from all_gather_ad
Conversation

@cw-tan commented Sep 17, 2024

What does this PR do?

This is the single-line enhancement proposed in #2745: enable propagation of the autograd graph after the all_gather operation. This is useful for syncing loss functions in a ddp setting.
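
For context, a minimal sketch of the idea (illustrative only; `gather_with_grad` is a made-up helper name, not the torchmetrics API): the tensors filled in by `all_gather` carry no autograd history, so the original local tensor is put back into its slot of the gathered list.

```python
import torch
import torch.distributed as dist

def gather_with_grad(tensor: torch.Tensor, group=None) -> list:
    """Gather `tensor` from all ranks while keeping the local autograd graph."""
    world_size = dist.get_world_size(group)
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, group=group)
    # the single-line idea: the gathered copies are detached, so re-insert the
    # original local tensor at this rank's slot to keep gradients flowing
    gathered[dist.get_rank(group)] = tensor
    return gathered
```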

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not needed for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2754.org.readthedocs.build/en/2754/

@Borda (Member) commented Sep 17, 2024

That sounds good to me, but can we add a test for this enhancement?

@cw-tan (Author) commented Sep 17, 2024

That sounds good to me, but can we add a test for this enhancement?

Thanks for the prompt response @Borda.

I'm thinking that _test_ddp_gather_uneven_tensors (here) and _test_ddp_gather_uneven_tensors_multidim (here) in tests/unittests/bases/test_ddp.py already cover the correctness of gather_all_tensors. I'm not sure what other ddp tests there are, but those tests should tell us whether my change breaks existing functionality. Let me know if you had something else in mind.

I can add a unittest in tests/unittests/bases/test_ddp.py that gives a tensor with requires_grad to gather_all_tensors, computes some scalar from the gathered tensors (a proxy for a loss), and then computes grads two ways (one going through the all_gather, one that doesn't) and compares them. This would test that the change achieves the desired effect. How does that sound?
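
For concreteness, a rough sketch of what I have in mind (the test name is a placeholder, and it assumes the process-group setup already used by the existing tests in tests/unittests/bases/test_ddp.py):

```python
import torch
from torchmetrics.utilities.distributed import gather_all_tensors

def _test_gather_all_tensors_autograd(rank: int, worldsize: int) -> None:
    # assumes the ddp process group is already initialized, as in the existing tests
    tensor = torch.ones(5, requires_grad=True)

    # path 1: loss computed from the gathered tensors (goes through all_gather)
    gathered = gather_all_tensors(tensor)
    loss_gathered = torch.stack(gathered).sum()
    (grad_gathered,) = torch.autograd.grad(loss_gathered, tensor)

    # path 2: the same local contribution computed without any all_gather
    loss_local = tensor.sum()
    (grad_local,) = torch.autograd.grad(loss_local, tensor)

    assert gathered[rank].requires_grad
    assert torch.allclose(grad_gathered, grad_local)
```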

codecov bot commented Sep 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 69%. Comparing base (748caee) to head (af23080).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #2754    +/-   ##
=======================================
- Coverage      69%     69%    -0%     
=======================================
  Files         329     316    -13     
  Lines       18077   17914   -163     
=======================================
- Hits        12496   12336   -160     
+ Misses       5581    5578     -3     

@Borda (Member) commented Sep 17, 2024

I can add a unittest in tests/unittests/bases/test_ddp.py that gives a tensor with requires_grad to gather_all_tensors, computes some scalar from the gathered tensors (a proxy for a loss), and then computes grads two ways (one going through the all_gather, one that doesn't) and compares them. This would test that the change achieves the desired effect. How does that sound?

yeah, that sounds good to me :)

@Borda added the enhancement (New feature or request) label on Sep 17, 2024
@cw-tan force-pushed the all_gather_ad branch 3 times, most recently from 1ba6fb3 to 6598ab8, on September 18, 2024 at 00:40
@cw-tan (Author) commented Sep 18, 2024

Update: to accommodate both cases, where tensors from different ranks have the same shape and where they have different shapes, the line that puts the original tensor (holding the AD graph) back into the gathered list was added in two places in the code.
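
Schematically, the uneven-shape branch now looks something like this (paraphrased sketch, not the exact torchmetrics source; the helper name and argument handling are made up):

```python
import torch
import torch.distributed as dist

def _gather_uneven(result: torch.Tensor, group=None) -> list:
    world_size = dist.get_world_size(group)

    # exchange the per-rank shapes so every rank knows the maximum size
    local_size = torch.tensor(result.shape)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size, group=group)
    max_size = torch.stack(all_sizes).max(dim=0).values

    # pad the local tensor up to the maximum size before gathering
    pad = []
    for dim in range(len(result.shape) - 1, -1, -1):
        pad += [0, int(max_size[dim] - result.shape[dim])]
    padded = torch.nn.functional.pad(result, pad)

    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded, group=group)
    # slice every gathered tensor back to its true shape
    gathered = [
        g[tuple(slice(0, int(d)) for d in all_sizes[i])] for i, g in enumerate(gathered)
    ]
    # second place where the local tensor is re-inserted, so the AD graph survives
    gathered[dist.get_rank(group)] = result
    return gathered
```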

Because of the two cases, I wrote two unittests, one for each. Interestingly, both pass on 2.X stable, but on 1.X LTS the "same shape" test passes while the "different shape" test fails, and on 1.10 (oldest) the "different shape" test passes while the "same shape" test fails 😅. I'll double-check for bugs, but the actual code change is just two lines (and all other tests pass, so existing functionality still works), and the unittests are pretty short. The fact that whether the unittests pass depends on the torch version suggests this might be a torch versioning issue, maybe to do with ddp behavior? Any thoughts, @Borda?

@Borda (Member) commented Sep 19, 2024

I wrote two unittests, one for each. Interestingly, both pass on 2.X stable, but on 1.X LTS the "same shape" test passes while the "different shape" test fails, and on 1.10 (oldest) the "different shape" test passes while the "same shape" test fails 😅.

that is strange and worth some more investigation...
cc: @SkafteNicki
