Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add time remaining column to progress bars #7273

Merged
merged 10 commits into from
Apr 26, 2024
4 changes: 3 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
with Progress(console=Console(theme=progressbar_theme)) as progress:
with Progress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
fonnesbeck marked this conversation as resolved.
Show resolved Hide resolved
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
for idx in np.arange(samples):
if nchain > 1:
Expand Down
4 changes: 3 additions & 1 deletion pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from rich.console import Console
from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme

from pymc.blocking import DictToArrayBijection
Expand Down Expand Up @@ -428,6 +428,8 @@ def __init__(
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
)
self._show_progress = progressbar
Expand Down
4 changes: 3 additions & 1 deletion pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cloudpickle
import numpy as np

from rich.progress import BarColumn, Progress, TimeRemainingColumn
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn

from pymc.backends.base import BaseTrace
from pymc.initial_point import PointType
Expand Down Expand Up @@ -180,6 +180,8 @@ def __init__(self, steppers, parallelize: bool, progressbar: bool = True):
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
) as self._progress:
for c, stepper in enumerate(steppers):
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
Expand Down
10 changes: 9 additions & 1 deletion pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
import numpy as np

from arviz import InferenceData
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

import pymc

Expand Down Expand Up @@ -366,6 +372,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
with Progress(
TextColumn("{task.description}"),
SpinnerColumn(),
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
TextColumn("{task.fields[status]}"),
) as progress:
Expand Down
Loading