Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blackjax sampler fix for breaking change / enable progress bar under parallel chain_method #7453

Merged
merged 7 commits into from
Aug 12, 2024

Conversation

andrewdipper
Copy link
Contributor

@andrewdipper andrewdipper commented Aug 9, 2024

blackjax-devs/blackjax#712 changes the expected jax.lax.scan carry in progress_bar_scan. Since pymc's external blackjax sampler directly uses progress_bar_scan it will break when progressbar=True. This change switches to use a new wrapper to hide the progress bar details. In addition it enables the use of progress bars under chain_method="parallel".

I think any breaking issues can be handled by restricting blackjax version numbers. However, I'm not sure how to properly do that?

And of course for now tests are expected to fail until the changes show in a blackjax release.

PRs that are dependencies:
blackjax-devs/blackjax#712
blackjax-devs/blackjax#716

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7453.org.readthedocs.build/en/7453/

@ricardoV94
Copy link
Member

ricardoV94 commented Aug 10, 2024

I think any breaking issues can be handled by restricting blackjax version numbers.

It's enough if we're compatible with the latest blackjax releases.

We can raise a runtime informative error if we know the installed version of blackjax is too old to work, directing users to update it

@junpenglao
Copy link
Member

junpenglao commented Aug 10, 2024

I guess this needs a new Blackjax release to work
Just cut a new Blackjax release, this should fix the test fail

@andrewdipper
Copy link
Contributor Author

andrewdipper commented Aug 10, 2024

Thanks for the new release

Looks like it needs to get into conda first

Copy link

codecov bot commented Aug 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.17%. Comparing base (48e56c3) to head (c58d94d).
Report is 6 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7453      +/-   ##
==========================================
- Coverage   92.20%   92.17%   -0.03%     
==========================================
  Files         103      103              
  Lines       17301    17258      -43     
==========================================
- Hits        15952    15908      -44     
- Misses       1349     1350       +1     
Files Coverage Δ
pymc/sampling/jax.py 94.78% <100.00%> (+0.75%) ⬆️

... and 1 file with indirect coverage changes

@junpenglao junpenglao merged commit 8cdc9ee into pymc-devs:main Aug 12, 2024
22 checks passed
@andrewdipper andrewdipper deleted the blackjax_pmap branch August 12, 2024 16:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants