Skip to content

Commit

Permalink
[Draft] Enable pmap progress bar with cpu backend / remove deprecated…
Browse files Browse the repository at this point in the history
… host_callback (#1841)

* initial progbar

* rename

* add lock

* comments + lock fix

* forgot chain comment

* switch to simple resource counter

* remove iter_num / fixes
  • Loading branch information
andrewdipper authored Aug 9, 2024
1 parent 643e2ca commit fb018d7
Showing 1 changed file with 42 additions and 38 deletions.
80 changes: 42 additions & 38 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import random
import re
from threading import Lock
import warnings

import numpy as np
Expand All @@ -18,7 +19,7 @@
import jax
from jax import device_put, jit, lax, vmap
from jax.core import Tracer
from jax.experimental import host_callback
from jax.experimental import io_callback
import jax.numpy as jnp

_DISABLE_CONTROL_FLOW_PRIM = False
Expand Down Expand Up @@ -201,58 +202,57 @@ def progress_bar_factory(num_samples, num_chains):

remainder = num_samples % print_rate

idx_counter = 0 # resource counter to assign chains to progress bars
tqdm_bars = {}
finished_chains = []
# lock serializes access to idx_counter since callbacks are multithreaded
# this prevents races that assign multiple chains to a progress bar
lock = Lock()
for chain in range(num_chains):
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

def _update_tqdm(arg, transform, device):
chain_match = _CHAIN_RE.search(str(device))
assert chain_match
chain = int(chain_match.group())
def _update_tqdm(increment, chain):
increment = int(increment)
chain = int(chain)
if chain == -1:
nonlocal idx_counter
with lock:
chain = idx_counter
idx_counter += 1
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
tqdm_bars[chain].update(arg)

def _close_tqdm(arg, transform, device):
chain_match = _CHAIN_RE.search(str(device))
assert chain_match
chain = int(chain_match.group())
tqdm_bars[chain].update(arg)
finished_chains.append(chain)
if len(finished_chains) == num_chains:
for chain in range(num_chains):
tqdm_bars[chain].close()

def _update_progress_bar(iter_num):
tqdm_bars[chain].update(increment)
return chain

def _close_tqdm(increment, chain):
increment = int(increment)
chain = int(chain)
tqdm_bars[chain].update(increment)
tqdm_bars[chain].close()

def _update_progress_bar(iter_num, chain):
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
Usage: carry = progress_bar((iter_num, print_rate), carry)
"""

_ = lax.cond(
chain = lax.cond(
iter_num == 1,
lambda _: host_callback.id_tap(
_update_tqdm, 0, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_tqdm, jnp.array(0), 0, chain),
lambda _: chain,
operand=None,
)
_ = lax.cond(
chain = lax.cond(
iter_num % print_rate == 0,
lambda _: host_callback.id_tap(
_update_tqdm, print_rate, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_tqdm, jnp.array(0), print_rate, chain),
lambda _: chain,
operand=None,
)
_ = lax.cond(
iter_num == num_samples,
lambda _: host_callback.id_tap(
_close_tqdm, remainder, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_close_tqdm, None, remainder, chain),
lambda _: None,
operand=None,
)
return chain

def progress_bar_fori_loop(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`.
Expand All @@ -261,9 +261,10 @@ def progress_bar_fori_loop(func):
"""

def wrapper_progress_bar(i, vals):
result = func(i, vals)
_update_progress_bar(i + 1)
return result
(subvals, chain) = vals
result = func(i, subvals)
chain = _update_progress_bar(i + 1, chain)
return (result, chain)

return wrapper_progress_bar

Expand Down Expand Up @@ -378,8 +379,11 @@ def loop_fn(collection):

def loop_fn(collection):
return fori_loop(
0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
)
0,
upper,
_body_fn_pbar,
((init_val, collection, start_idx, thinning), -1), # -1 for chain id
)[0]

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)

Expand Down

0 comments on commit fb018d7

Please sign in to comment.