Skip to content

Commit

Permalink
comments + lock fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Aug 6, 2024
1 parent 8df33ca commit 819416c
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def progress_bar_factory(num_samples, num_chains):
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

# Uses resource counting for each iter_num value. Chains are assigned to progress
# bars based on order of arrival to each iter_num value
def _calc_chain_idx(iter_num):
with lock:
try:
Expand Down Expand Up @@ -237,9 +239,10 @@ def _close_tqdm(iter_num, increment):
chain = _calc_chain_idx(iter_num + 1) # +1 so no collision in idx_map
tqdm_bars[chain].update(increment)
finished_chains.append(chain)
if len(finished_chains) == num_chains:
for chain in range(num_chains):
tqdm_bars[chain].close()
with lock:
if len(finished_chains) == num_chains:
for chain in range(num_chains):
tqdm_bars[chain].close()

def _update_progress_bar(iter_num):
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
Expand All @@ -248,7 +251,7 @@ def _update_progress_bar(iter_num):

_ = lax.cond(
iter_num == 1,
lambda _: io_callback(_update_tqdm, None, -1, 0),
lambda _: io_callback(_update_tqdm, None, 0, 0),
lambda _: None,
operand=None,
)
Expand Down

0 comments on commit 819416c

Please sign in to comment.