Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoids setting jax tracer as lazy property attribute #1843

Merged
merged 8 commits into from
Aug 9, 2024

Conversation

sbidari
Copy link
Contributor

@sbidari sbidari commented Aug 6, 2024

Add a conditional if numpyro.util.not_jax_tracer(value) before setting lazy_property as an attribute of TruncatedDistribution

This resolves errors described in #1836 and CDCgov/PyRenew#282

tests added:
check valid sampling of TruncatedDistribution in parallel (failed previously as described in #1836 )
check predictive methods (prior predictive and inference) can be run multiple times on the same model built using TruncatedDistribution (error encountered here CDCgov/PyRenew#282)

@sbidari sbidari changed the title remove tracer as an attribute of truncated distribution avoids setting jax tracer as lazy property attribute Aug 6, 2024
Copy link
Contributor

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this, @sbidari! Two thoughts:

  • I think some aspects of the second test function have been unnecessarily carried over from the first test function and can be removed. See below.
  • I wonder whether it might be worth annotating the test functions (via their docstrings) to explain why they should run without error if things are working and why/how they fail prior without the patch in this PR

test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
test/test_distributions.py Outdated Show resolved Hide resolved
@fehiepsi fehiepsi marked this pull request as ready for review August 7, 2024 04:50
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending reviews from @dylanhmorris. Thanks @sbidari!

@sbidari
Copy link
Contributor Author

sbidari commented Aug 7, 2024

Thanks @fehiepsi!

@dylanhmorris I added docstring with links to where issues are described and combined the two test functions.

Copy link
Contributor

@dylanhmorris dylanhmorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @sbidari. One suggestion you can take or leave on naming the test (since it should pass if there isn't tracer leakage). @fehiepsi: this LGTM.

test/test_distributions.py Outdated Show resolved Hide resolved
Co-authored-by: Dylan H. Morris <[email protected]>
@sbidari
Copy link
Contributor Author

sbidari commented Aug 7, 2024

I am not sure why the test is failing in CI with UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. despite using numpyro.set_host_device_count(2) at the beginning. Any thoughts @fehiepsi?

test/test_distributions.py Outdated Show resolved Hide resolved
@dylanhmorris
Copy link
Contributor

Noting here that when this lands it will also provide a more general fix for #1651

sbidari and others added 2 commits August 8, 2024 09:50
Move test from test_distributions.py to test_distributions_util.py
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sbidari!

@fehiepsi fehiepsi merged commit d61f15c into pyro-ppl:master Aug 9, 2024
4 checks passed
@sbidari sbidari deleted the remove-tracer-as-attribute branch August 10, 2024 04:50
@damonbayer
Copy link
Contributor

@fehiepsi Is there a release planned for this fix? We would appreciate having it for cdcgov/multisignal-epi-inference as soon as possible.

@fehiepsi
Copy link
Member

I can make a patch release in upcoming days. In the mean time, you can patch the utility.

@damonbayer
Copy link
Contributor

@fehiepsi Do you have a date in mind?

@fehiepsi fehiepsi mentioned this pull request Sep 5, 2024
dlp-rb added a commit to rockerbox/numpyro that referenced this pull request Sep 10, 2024
This commit applies the changes Subekshya Bidari authored in [Numpyro PR
number 1843][pr] to fix using a TruncatedNormal distribution (or any
other TwoSidedTruncatedDistribution) when running multiple chains in
parallel using the NUTs MCMC sampler.

[pr]: pyro-ppl#1843
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants