Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CustomProgress that does not output empty divs when disabled #7290

Merged
merged 7 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
from pytensor.graph import ancestors
from pytensor.tensor.sharedvar import SharedVariable
from rich.progress import Console, Progress
from rich.progress import Console
from rich.theme import Theme
from xarray import Dataset

import pymc

from pymc.model import Model, modelcontext
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import default_progress_theme, get_default_varnames
from pymc.util import CustomProgress, default_progress_theme, get_default_varnames

if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
Expand Down Expand Up @@ -649,8 +649,10 @@ def apply_function_over_dataset(
out_dict = _DefaultTrace(n_pts)
indices = range(n_pts)

with Progress(console=Console(theme=progressbar_theme), disable=not progressbar) as progress:
task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
with CustomProgress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Computing ...", total=n_pts)
for idx in indices:
out = fn(posterior_pts[idx])
fn.f.trust_input = True # If we arrive here the dtypes are valid
Expand Down
22 changes: 17 additions & 5 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from pytensor.tensor.sharedvar import SharedVariable
from rich.console import Console
from rich.progress import Progress
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme

import pymc as pm
Expand All @@ -55,6 +55,7 @@
from pymc.model import Model, modelcontext
from pymc.pytensorf import compile_pymc
from pymc.util import (
CustomProgress,
RandomState,
_get_seeds_per_chain,
default_progress_theme,
Expand Down Expand Up @@ -828,11 +829,21 @@ def sample_posterior_predictive(
# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)

progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)

try:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
with progress:
task = progress.add_task("Sampling ...", completed=0, total=samples)
for idx in np.arange(samples):
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
Expand All @@ -854,6 +865,7 @@ def sample_posterior_predictive(
ppc_trace_t.insert(k.name, v, idx)

progress.advance(task)
progress.update(task, refresh=True, completed=samples)

except KeyboardInterrupt:
pass
Expand Down
25 changes: 20 additions & 5 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from arviz.data.base import make_attrs
from pytensor.graph.basic import Variable
from rich.console import Console
from rich.progress import Progress
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
Expand Down Expand Up @@ -65,6 +65,7 @@
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
CustomProgress,
RandomSeed,
RandomState,
_get_seeds_per_chain,
Expand Down Expand Up @@ -1075,14 +1076,28 @@ def _sample(
)
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
with Progress(console=Console(theme=progressbar_theme)) as progress:

progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)

with progress:
try:
task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar)
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, refresh=True, advance=1)
progress.update(task, refresh=True, advance=1, completed=True)
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
progress.update(
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
)
except KeyboardInterrupt:
pass

Expand Down
8 changes: 3 additions & 5 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
import numpy as np

from rich.console import Console
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import RandomSeed, default_progress_theme
from pymc.util import CustomProgress, RandomSeed, default_progress_theme

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -431,7 +431,7 @@ def __init__(

self._in_context = False

self._progress = Progress(
self._progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
Expand Down Expand Up @@ -465,7 +465,6 @@ def __iter__(self):
self._desc.format(self),
completed=self._completed_draws,
total=self._total_draws,
visible=self._show_progress,
)

while self._active:
Expand All @@ -476,7 +475,6 @@ def __iter__(self):
self._divergences += 1
progress.update(
task,
refresh=True,
completed=self._completed_draws,
total=self._total_draws,
description=self._desc.format(self),
Expand Down
18 changes: 8 additions & 10 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cloudpickle
import numpy as np

from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from pymc.backends.base import BaseTrace
from pymc.initial_point import PointType
Expand All @@ -37,7 +37,7 @@
StatsType,
)
from pymc.step_methods.metropolis import DEMetropolis
from pymc.util import RandomSeed
from pymc.util import CustomProgress, RandomSeed

__all__ = ()

Expand Down Expand Up @@ -100,11 +100,10 @@ def _sample_population(
progressbar=progressbar,
)

with Progress() as progress:
task = progress.add_task("[red]Sampling...", total=draws, visible=progressbar)

with CustomProgress(disable=not progressbar) as progress:
task = progress.add_task("[red]Sampling...", total=draws)
for _ in sampling:
progress.update(task, advance=1, refresh=True)
progress.update(task)

return

Expand Down Expand Up @@ -175,20 +174,19 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
)
import multiprocessing

with Progress(
with CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
disable=not progressbar,
) as self._progress:
for c, stepper in enumerate(steppers):
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
# ):
task = self._progress.add_task(
description=f"Chain {c}", visible=progressbar
)
task = self._progress.add_task(description=f"Chain {c}")
secondary_end, primary_end = multiprocessing.Pipe()
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
process = multiprocessing.Process(
Expand Down
10 changes: 4 additions & 6 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from arviz import InferenceData
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
Expand All @@ -41,7 +40,7 @@
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH
from pymc.stats.convergence import log_warnings, run_convergence_checks
from pymc.util import RandomState, _get_seeds_per_chain
from pymc.util import CustomProgress, RandomState, _get_seeds_per_chain


def sample_smc(
Expand Down Expand Up @@ -369,13 +368,14 @@ def _sample_smc_int(


def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
with Progress(
with CustomProgress(
TextColumn("{task.description}"),
SpinnerColumn(),
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
TextColumn("{task.fields[status]}"),
disable=not progressbar,
) as progress:
futures = [] # keep track of the jobs
with multiprocessing.Manager() as manager:
Expand All @@ -390,9 +390,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
with ProcessPoolExecutor(max_workers=cores) as executor:
for c in range(chains): # iterate over the jobs we need to run
# set visible false so we don't have a lot of bars all at once:
task_id = progress.add_task(
f"Chain {c}", status="Stage: 0 Beta: 0", visible=progressbar
)
task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
futures.append(
executor.submit(
_sample_smc_int,
Expand Down
11 changes: 8 additions & 3 deletions pymc/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model import modelcontext
from pymc.util import default_progress_theme, get_default_varnames, get_value_vars_from_user_vars
from pymc.util import (
CustomProgress,
default_progress_theme,
get_default_varnames,
get_value_vars_from_user_vars,
)
from pymc.vartypes import discrete_types, typefilter

__all__ = ["find_MAP"]
Expand Down Expand Up @@ -219,13 +224,13 @@ def __init__(
self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}"
self.previous_x = None
self.progressbar = progressbar
self.progress = Progress(
self.progress = CustomProgress(
*Progress.get_default_columns(),
TextColumn("{task.fields[loss]}"),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)
self.task = self.progress.add_task("MAP", total=maxeval, visible=progressbar, loss="")
self.task = self.progress.add_task("MAP", total=maxeval, loss="")

def __call__(self, x):
neg_value = np.float64(self.logp_func(pm.floatX(x)))
Expand Down
58 changes: 58 additions & 0 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph.utils import ValidatingScratchpad
from rich.progress import Progress
from rich.theme import Theme

from pymc.exceptions import BlockModelAccessError
Expand Down Expand Up @@ -520,3 +521,60 @@ def makeiter(a):
return a
else:
return [a]


class CustomProgress(Progress):
    """A ``Progress`` subclass whose bars and container can be fully disabled.

    The instance keeps an ``is_enabled`` flag derived from the ``disable``
    keyword argument. When the flag is ``False`` the base-class initializer is
    skipped and the overridden methods below become no-ops, so no progress
    output is generated at all.

    Only the methods overridden here are guarded. Because the base class is
    not initialized while disabled, other inherited methods may fail on a
    disabled instance.
    """

    def __init__(self, *args, **kwargs):
        """Create the progress display unless ``disable`` is truthy.

        All positional and keyword arguments are forwarded unchanged to
        ``Progress.__init__`` when the instance is enabled.
        """
        # Use truthiness rather than an identity check against ``True`` so that
        # values such as ``1`` or ``numpy.bool_(True)`` also disable the bar
        # instead of slipping through to the base class.
        self.is_enabled = not kwargs.get("disable", False)
        if self.is_enabled:
            super().__init__(*args, **kwargs)

    def __enter__(self):
        """Start the display when enabled and return ``self`` in either case."""
        if self.is_enabled:
            self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Delegate to the base-class ``__exit__`` only when enabled."""
        if self.is_enabled:
            super().__exit__(exc_type, exc_val, exc_tb)

    def add_task(self, *args, **kwargs):
        """Register a task and return its id, or return ``None`` when disabled."""
        if self.is_enabled:
            return super().add_task(*args, **kwargs)
        return None

    def advance(self, task_id, advance=1) -> None:
        """Advance ``task_id`` by ``advance`` steps; do nothing when disabled."""
        if self.is_enabled:
            super().advance(task_id, advance)
        return None

    def update(
        self,
        task_id,
        *,
        total=None,
        completed=None,
        advance=None,
        description=None,
        visible=None,
        refresh=False,
        **fields,
    ):
        """Forward a task update to the base class; do nothing when disabled.

        Every keyword argument is passed through unchanged. The method always
        returns ``None``.
        """
        if self.is_enabled:
            super().update(
                task_id,
                total=total,
                completed=completed,
                advance=advance,
                description=description,
                visible=visible,
                refresh=refresh,
                **fields,
            )
        return None
Loading
Loading