From 321a5cc37d89cd5ea1cf00efaf8abf6762fdbff5 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sun, 28 Apr 2024 20:01:28 -0300 Subject: [PATCH 1/7] Replace Progress with CustomProgress --- pymc/backends/arviz.py | 10 ++++++---- pymc/sampling/forward.py | 6 +++--- pymc/sampling/mcmc.py | 7 +++++-- pymc/sampling/parallel.py | 27 +++++++++++++++------------ pymc/sampling/population.py | 15 +++++++-------- pymc/smc/sampling.py | 10 ++++------ pymc/tuning/starting.py | 11 ++++++++--- pymc/util.py | 34 ++++++++++++++++++++++++++++++++++ pymc/variational/inference.py | 11 ++++++----- 9 files changed, 88 insertions(+), 43 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index d7238dc96aa..870193538be 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -32,7 +32,7 @@ from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable -from rich.progress import Console, Progress +from rich.progress import Console from rich.theme import Theme from xarray import Dataset @@ -40,7 +40,7 @@ from pymc.model import Model, modelcontext from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import default_progress_theme, get_default_varnames +from pymc.util import CustomProgress, default_progress_theme, get_default_varnames if TYPE_CHECKING: from pymc.backends.base import MultiTrace @@ -649,8 +649,10 @@ def apply_function_over_dataset( out_dict = _DefaultTrace(n_pts) indices = range(n_pts) - with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress: - task = progress.add_task("Computing ...", total=n_pts, visible=progressbar) + with CustomProgress( + console=Console(theme=progressbar_theme), disable=not progressbar + ) as progress: + task = progress.add_task("Computing ...", total=n_pts) for idx in indices: out = fn(posterior_pts[idx]) fn.f.trust_input = True # If we arrive here the dtypes are valid diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index cc1dcd52f6e..6f1bf4be6e3 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -44,7 +44,6 @@ ) from pytensor.tensor.sharedvar import SharedVariable from rich.console import Console -from rich.progress import Progress from rich.theme import Theme import pymc as pm @@ -55,6 +54,7 @@ from pymc.model import Model, modelcontext from pymc.pytensorf import compile_pymc from pymc.util import ( + CustomProgress, RandomState, _get_seeds_per_chain, default_progress_theme, @@ -829,10 +829,10 @@ def sample_posterior_predictive( _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) try: - with Progress( + with CustomProgress( console=Console(theme=progressbar_theme), disable=not progressbar ) as progress: - task = progress.add_task("Sampling ...", total=samples, visible=progressbar) + task = progress.add_task("Sampling ...", total=samples) for idx in np.arange(samples): if nchain > 1: # the trace object will either be a MultiTrace (and have _straces)... diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f207d1ce98c..3c626c0db56 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -65,6 +65,7 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( + CustomProgress, RandomSeed, RandomState, _get_seeds_per_chain, @@ -1075,9 +1076,11 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - with Progress(console=Console(theme=progressbar_theme)) as progress: + with CustomProgress( + console=Console(theme=progressbar_theme), disable=not progressbar + ) as progress: try: - task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar) + task = progress.add_task(_desc.format(**_pbar_data), total=draws) for it, diverging in enumerate(sampling_gen): if it >= skip_first and diverging: _pbar_data["divergences"] += 1 diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 11bb5e49e37..a3b41032a47 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -27,13 +27,13 @@ import numpy as np from rich.console import Console -from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError -from pymc.util import RandomSeed, default_progress_theme +from pymc.util import CustomProgress, RandomSeed, default_progress_theme logger = logging.getLogger(__name__) @@ -431,7 +431,7 @@ def __init__( self._in_context = False - self._progress = Progress( + self._progress = CustomProgress( "[progress.description]{task.description}", BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", @@ -465,7 +465,6 @@ def __iter__(self): self._desc.format(self), completed=self._completed_draws, total=self._total_draws, - visible=self._show_progress, ) while self._active: @@ -474,20 +473,24 @@ def __iter__(self): self._completed_draws += 1 if not tuning and stats and stats[0].get("diverging"): self._divergences += 1 - progress.update( - task, - refresh=True, - completed=self._completed_draws, - total=self._total_draws, - description=self._desc.format(self), - ) + + if self._progress.is_enabled: + progress.update( + task, + refresh=True, + completed=self._completed_draws, + total=self._total_draws, + description=self._desc.format(self), + ) if is_last: proc.join() self._active.remove(proc) self._finished.append(proc) self._make_active() - progress.update(task, description=self._desc.format(self), refresh=True) + + if self._progress.is_enabled: + progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 2b0aad2b32a..4d0cbe0024c 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -24,7 +24,7 @@ import cloudpickle import numpy as np -from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from pymc.backends.base import BaseTrace from pymc.initial_point import PointType @@ -37,7 +37,7 @@ StatsType, ) from pymc.step_methods.metropolis import DEMetropolis -from pymc.util import RandomSeed +from pymc.util import CustomProgress, RandomSeed __all__ = () @@ -100,8 +100,8 @@ def _sample_population( progressbar=progressbar, ) - with Progress() as progress: - task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar) + with CustomProgress(disable=not progressbar) as progress: + task = progress.add_task("[red]Sampling...", total=draws) for _ in sampling: progress.update(task, advance=1, refresh=True) @@ -175,20 +175,19 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True): ) import multiprocessing - with Progress( + with CustomProgress( "[progress.description]{task.description}", BarColumn(), "[progress.percentage]{task.percentage:>3.0f}%", TimeRemainingColumn(), TextColumn("/"), TimeElapsedColumn(), + disable=not progressbar, ) as self._progress: for c, stepper in enumerate(steppers): # enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) # ): - task = self._progress.add_task( - description=f"Chain {c}", visible=progressbar - ) + task = self._progress.add_task(description=f"Chain {c}") secondary_end, primary_end = multiprocessing.Pipe() stepper_dumps = cloudpickle.dumps(stepper, protocol=4) process = multiprocessing.Process( diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index db4044a4fe9..ad6f7ede412 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -26,7 +26,6 @@ from arviz import InferenceData from rich.progress import ( - Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, @@ -41,7 +40,7 @@ from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH from pymc.stats.convergence import log_warnings, run_convergence_checks -from pymc.util import RandomState, _get_seeds_per_chain +from pymc.util import CustomProgress, RandomState, _get_seeds_per_chain def sample_smc( @@ -369,13 +368,14 @@ def _sample_smc_int( def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): - with Progress( + with CustomProgress( TextColumn("{task.description}"), SpinnerColumn(), TimeRemainingColumn(), TextColumn("/"), TimeElapsedColumn(), TextColumn("{task.fields[status]}"), + disable=not progressbar, ) as progress: futures = [] # keep track of the jobs with multiprocessing.Manager() as manager: @@ -390,9 +390,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): with ProcessPoolExecutor(max_workers=cores) as executor: for c in range(chains): # iterate over the jobs we need to run # set visible false so we don't have a lot of bars all at once: - task_id = progress.add_task( - f"Chain {c}", status="Stage: 0 Beta: 0", visible=progressbar - ) + task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0") futures.append( executor.submit( _sample_smc_int, diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 09b787c506d..cb8ae010d75 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -37,7 +37,12 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext -from pymc.util import default_progress_theme, get_default_varnames, get_value_vars_from_user_vars +from pymc.util import ( + CustomProgress, + default_progress_theme, + get_default_varnames, + get_value_vars_from_user_vars, +) from pymc.vartypes import discrete_types, typefilter __all__ = ["find_MAP"] @@ -219,13 +224,13 @@ def __init__( self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}" self.previous_x = None self.progressbar = progressbar - self.progress = Progress( + self.progress = CustomProgress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), console=Console(theme=progressbar_theme), disable=not progressbar, ) - self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="") + self.task = self.progress.add_task("MAP", total=maxeval, loss="") def __call__(self, x): neg_value = np.float64(self.logp_func(pm.floatX(x))) diff --git a/pymc/util.py b/pymc/util.py index ccf97c89a31..cd77e0bc25b 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -27,6 +27,7 @@ from pytensor import Variable from pytensor.compile import SharedVariable from pytensor.graph.utils import ValidatingScratchpad +from rich.progress import Progress from rich.theme import Theme from pymc.exceptions import BlockModelAccessError @@ -520,3 +521,36 @@ def makeiter(a): return a else: return [a] + + +class CustomProgress(Progress): + """A child of Progress that allows to disable progress bars and its container + + The implementation simply checks an `is_enabled` flag and generates the progress bar only if + it's `True`. + """ + + def __init__(self, *args, **kwargs): + self.is_enabled = not kwargs.get("disable", None) is True + if self.is_enabled: + super().__init__(*args, **kwargs) + + def __enter__(self): + if self.is_enabled: + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.is_enabled: + super().__exit__(exc_type, exc_val, exc_tb) + + def add_task(self, *args, **kwargs): + if self.is_enabled: + return super().add_task(*args, **kwargs) + return None + + def advance(self, task_id, advance=1) -> None: + if self.is_enabled: + super().advance(task_id, advance) + self.refresh() + return None diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index 3a9a69add72..42dad4a404b 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -23,7 +23,7 @@ import pymc as pm -from pymc.util import default_progress_theme +from pymc.util import CustomProgress, default_progress_theme from pymc.variational import test_functions from pymc.variational.approximations import Empirical, FullRank, MeanField from pymc.variational.operators import KL, KSD @@ -166,10 +166,10 @@ def fit( def _iterate_without_loss(self, s, n, step_func, progressbar, progressbar_theme, callbacks): i = 0 try: - with Progress( + with CustomProgress( console=Console(theme=progressbar_theme), disable=not progressbar ) as progress: - task = progress.add_task("Fitting", total=n, visible=progressbar) + task = progress.add_task("Fitting", total=n) for i in range(n): step_func() progress.update(task, advance=1) @@ -217,12 +217,13 @@ def _infmean(input_array): scores[:] = np.nan i = 0 try: - with Progress( + with CustomProgress( *Progress.get_default_columns(), TextColumn("{task.fields[loss]}"), console=Console(theme=progressbar_theme), + disable=not progressbar, ) as progress: - task = progress.add_task("Fitting:", total=n, visible=progressbar, loss="") + task = progress.add_task("Fitting:", total=n, loss="") for i in range(n): e = step_func() progress.update(task, advance=1) From 355c7e124e95bb6bc4e6a5ca94a7b86c989a98e2 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Wed, 1 May 2024 12:42:15 -0300 Subject: [PATCH 2/7] Add update method to CustomProgress --- pymc/sampling/parallel.py | 19 ++++++++----------- pymc/util.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index a3b41032a47..7d65d26541e 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -474,23 +474,20 @@ def __iter__(self): if not tuning and stats and stats[0].get("diverging"): self._divergences += 1 - if self._progress.is_enabled: - progress.update( - task, - refresh=True, - completed=self._completed_draws, - total=self._total_draws, - description=self._desc.format(self), - ) + progress.update( + task, + refresh=True, + completed=self._completed_draws, + total=self._total_draws, + description=self._desc.format(self), + ) if is_last: proc.join() self._active.remove(proc) self._finished.append(proc) self._make_active() - - if self._progress.is_enabled: - progress.update(task, description=self._desc.format(self), refresh=True) + progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/util.py b/pymc/util.py index cd77e0bc25b..fe55813385e 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -531,7 +531,7 @@ class CustomProgress(Progress): """ def __init__(self, *args, **kwargs): - self.is_enabled = not kwargs.get("disable", None) is True + self.is_enabled = kwargs.get("disable", None) is not True if self.is_enabled: super().__init__(*args, **kwargs) @@ -552,5 +552,29 @@ def add_task(self, *args, **kwargs): def advance(self, task_id, advance=1) -> None: if self.is_enabled: super().advance(task_id, advance) - self.refresh() + return None + + def update( + self, + task_id, + *, + total=None, + completed=None, + advance=None, + description=None, + visible=None, + refresh=False, + **fields, + ): + if self.is_enabled: + super().update( + task_id, + total=total, + completed=completed, + advance=advance, + description=description, + visible=visible, + refresh=refresh, + **fields, + ) return None From 5575767097a817da44bf8f1c96c514e9fa044336 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Wed, 1 May 2024 13:35:12 -0300 Subject: [PATCH 3/7] remove unused import --- pymc/sampling/mcmc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 3c626c0db56..f11d957f3cd 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -36,7 +36,6 @@ from arviz.data.base import make_attrs from pytensor.graph.basic import Variable from rich.console import Console -from rich.progress import Progress from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol From 974c89631d41bd69c6156c4ac84f6f8b9a6f1e79 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sun, 28 Apr 2024 20:01:28 -0300 Subject: [PATCH 4/7] Replace Progress with CustomProgress --- pymc/sampling/parallel.py | 4 +++- pymc/util.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 7d65d26541e..97f62e5b5e5 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -487,7 +487,9 @@ def __iter__(self): self._active.remove(proc) self._finished.append(proc) self._make_active() - progress.update(task, description=self._desc.format(self), refresh=True) + + if self._progress.is_enabled: + progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/util.py b/pymc/util.py index fe55813385e..bfce02b34ee 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -531,7 +531,7 @@ class CustomProgress(Progress): """ def __init__(self, *args, **kwargs): - self.is_enabled = kwargs.get("disable", None) is not True + self.is_enabled = not kwargs.get("disable", None) is True if self.is_enabled: super().__init__(*args, **kwargs) From ecabab5c20157402ec124e1a68797421e1199f11 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Wed, 1 May 2024 12:42:15 -0300 Subject: [PATCH 5/7] Add update method to CustomProgress --- pymc/sampling/parallel.py | 4 +--- pymc/util.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 97f62e5b5e5..7d65d26541e 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -487,9 +487,7 @@ def __iter__(self): self._active.remove(proc) self._finished.append(proc) self._make_active() - - if self._progress.is_enabled: - progress.update(task, description=self._desc.format(self), refresh=True) + progress.update(task, description=self._desc.format(self), refresh=True) # We could also yield proc.shared_point_view directly, # and only call proc.write_next() after the yield returns. diff --git a/pymc/util.py b/pymc/util.py index bfce02b34ee..fe55813385e 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -531,7 +531,7 @@ class CustomProgress(Progress): """ def __init__(self, *args, **kwargs): - self.is_enabled = not kwargs.get("disable", None) is True + self.is_enabled = kwargs.get("disable", None) is not True if self.is_enabled: super().__init__(*args, **kwargs) From d245a01c6f61a44bc435289454d0bdec24ceb369 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Wed, 1 May 2024 22:48:28 -0300 Subject: [PATCH 6/7] Remove some refreshes that slow things down --- pymc/sampling/mcmc.py | 2 +- pymc/sampling/parallel.py | 1 - pymc/sampling/population.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f11d957f3cd..7e690afe1b3 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -1083,7 +1083,7 @@ def _sample( for it, diverging in enumerate(sampling_gen): if it >= skip_first and diverging: _pbar_data["divergences"] += 1 - progress.update(task, refresh=True, advance=1) + progress.update(task) progress.update(task, refresh=True, advance=1, completed=True) except KeyboardInterrupt: pass diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 7d65d26541e..fb0e1fbc0d7 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -476,7 +476,6 @@ def __iter__(self): progress.update( task, - refresh=True, completed=self._completed_draws, total=self._total_draws, description=self._desc.format(self), diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 4d0cbe0024c..ae7b637e246 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -104,7 +104,7 @@ def _sample_population( task = progress.add_task("[red]Sampling...", total=draws) for _ in sampling: - progress.update(task, advance=1, refresh=True) + progress.update(task) return From 292dae99cb6efb3a121510c7ceee7c3774784f4b Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Sat, 11 May 2024 18:18:18 -0300 Subject: [PATCH 7/7] Remove some 'refresh' and make sure progress goes to 100% --- pymc/sampling/forward.py | 20 ++++++++++++++++---- pymc/sampling/mcmc.py | 25 +++++++++++++++++++------ pymc/sampling/parallel.py | 1 - pymc/sampling/population.py | 1 - 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 6f1bf4be6e3..23b3a601658 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -44,6 +44,7 @@ ) from pytensor.tensor.sharedvar import SharedVariable from rich.console import Console +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme import pymc as pm @@ -828,11 +829,21 @@ def sample_posterior_predictive( # All model variables have a name, but mypy does not know this _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore ppc_trace_t = _DefaultTrace(samples) + + progress = CustomProgress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), + console=Console(theme=progressbar_theme), + disable=not progressbar, + ) + try: - with CustomProgress( - console=Console(theme=progressbar_theme), disable=not progressbar - ) as progress: - task = progress.add_task("Sampling ...", total=samples) + with progress: + task = progress.add_task("Sampling ...", completed=0, total=samples) for idx in np.arange(samples): if nchain > 1: # the trace object will either be a MultiTrace (and have _straces)... @@ -854,6 +865,7 @@ def sample_posterior_predictive( ppc_trace_t.insert(k.name, v, idx) progress.advance(task) + progress.update(task, refresh=True, completed=samples) except KeyboardInterrupt: pass diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7e690afe1b3..f2ef43e8b58 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -36,6 +36,7 @@ from arviz.data.base import make_attrs from pytensor.graph.basic import Variable from rich.console import Console +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol @@ -1075,16 +1076,28 @@ def _sample( ) _pbar_data = {"chain": chain, "divergences": 0} _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" - with CustomProgress( - console=Console(theme=progressbar_theme), disable=not progressbar - ) as progress: + + progress = CustomProgress( + "[progress.description]{task.description}", + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeRemainingColumn(), + TextColumn("/"), + TimeElapsedColumn(), + console=Console(theme=progressbar_theme), + disable=not progressbar, + ) + + with progress: try: - task = progress.add_task(_desc.format(**_pbar_data), total=draws) + task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws) for it, diverging in enumerate(sampling_gen): if it >= skip_first and diverging: _pbar_data["divergences"] += 1 - progress.update(task) - progress.update(task, refresh=True, advance=1, completed=True) + progress.update(task, description=_desc.format(**_pbar_data), completed=it) + progress.update( + task, description=_desc.format(**_pbar_data), completed=draws, refresh=True + ) except KeyboardInterrupt: pass diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index fb0e1fbc0d7..9f950f621f1 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -473,7 +473,6 @@ def __iter__(self): self._completed_draws += 1 if not tuning and stats and stats[0].get("diverging"): self._divergences += 1 - progress.update( task, completed=self._completed_draws, diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index ae7b637e246..4d5ced3f522 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -102,7 +102,6 @@ def _sample_population( with CustomProgress(disable=not progressbar) as progress: task = progress.add_task("[red]Sampling...", total=draws) - for _ in sampling: progress.update(task)