Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate progress bar from fastprogress to tqdm #655

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

zaxtax
Copy link
Contributor

@zaxtax zaxtax commented Apr 1, 2024

Moves to use tqdm along with adding support for multiple progress bars

Makes blackjax suitable for running multiple chains in parallel.

Copy link

codecov bot commented Apr 1, 2024

Codecov Report

Attention: Patch coverage is 91.30435% with 2 lines in your changes are missing coverage. Please review.

Project coverage is 98.80%. Comparing base (7cf4f9d) to head (b2d2273).
Report is 1 commits behind head on main.

❗ Current head b2d2273 differs from pull request most recent head bf703cd. Consider uploading reports for the commit bf703cd to get more accurate results

Files Patch % Lines
blackjax/progress_bar.py 91.30% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #655      +/-   ##
==========================================
- Coverage   98.87%   98.80%   -0.07%     
==========================================
  Files          59       59              
  Lines        2745     2752       +7     
==========================================
+ Hits         2714     2719       +5     
- Misses         31       33       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@zaxtax zaxtax marked this pull request as ready for review April 1, 2024 13:20
@@ -14,70 +14,83 @@
"""Progress bar decorators for use with step functions.
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
"""
from fastprogress.fastprogress import progress_bar
import jax
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from jax.debug import callback?

def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
def progress_bar_scan(num_samples, num_chains=1, print_rate=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC in the usage we need to specify the num_chains for pmap to work properly. Could you explain a bit more how you are planning to change the API for downstream application so that part works?

one_step = progress_bar_scan(num_steps)(_one_step)

one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not fully committed to this API, but I was thinking something where along with passing an array of iteration numbers, you also pass in the chain you are currently in. I think this is better than the numpyro design where you are using regexes on device objects to guess what chain to put the computation on.

def inference_loop(rng_key, kernel, initial_state, chain, num_samples, num_chains):

    def _one_step(state, xs):
        _, _, rng_key = xs
        state, _ = kernel(rng_key, state)
        return state, state
    one_step = jax.jit(progress_bar_factory(num_samples, num_chains)(_one_step))

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(
        one_step,
        initial_state,
        (np.arange(num_samples), chain * np.ones(num_samples), keys),
    )

    return states
    
inference_loop_multiple_chains = jax.pmap(
    inference_loop,
    in_axes=(0, None, 0, 0, None, None),
    static_broadcasted_argnums=(1, 4, 5),
    devices=jax.devices(),
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For downstream applications that don't use multiple chains, I have included logic to maintain backward compatibility. Though I'm not sure how actual code is implementing progress bars for multiple chains today.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, could you share a small jupyter notebook how it looks like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that's very helpful. Let me think about it a bit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfectly happy to rework the API. This is was an attempt to make something simple and backwards compatible.

@andrewdipper andrewdipper mentioned this pull request Aug 3, 2024
10 tasks
@junpenglao
Copy link
Member

@zaxtax should we update this after #712 and get it merge?

@zaxtax
Copy link
Contributor Author

zaxtax commented Sep 27, 2024 via email

@junpenglao
Copy link
Member

I forgot the reason why we were doing this beside the pmap bug (which is now fixed)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants