Skip to content

Commit

Permalink
fix: fix progress bar for parallel objects that return generators
Browse files Browse the repository at this point in the history
Tqdm imports are also simplified. We no longer handle `is_notebook` ourselves, but just import from `tqdm.auto`
  • Loading branch information
kmnhan committed Apr 21, 2024
1 parent fd8e1ad commit 23d41b3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 33 deletions.
34 changes: 26 additions & 8 deletions src/erlab/analysis/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import numpy.typing as npt
import scipy.interpolate
import tqdm.auto
import uncertainties
import xarray as xr

Expand Down Expand Up @@ -243,6 +244,11 @@ def edge(
parallel_kw.setdefault("n_jobs", -1)
else:
parallel_kw.setdefault("n_jobs", 1)

parallel_kw.setdefault("max_nbytes", None)
parallel_kw.setdefault("return_as", "generator")
parallel_kw.setdefault("pre_dispatch", "n_jobs")

parallel_obj = joblib.Parallel(**parallel_kw)

if normalize:
Expand All @@ -260,23 +266,35 @@ def _fit(data, w):
method=method,
scale_covar=scale_covar,
weights=w,
**kwargs,
)
return res

if progress:
with joblib_progress(desc="Fitting", total=n_fits) as _:
tqdm_kw = {"desc": "Fitting", "total": n_fits, "disable": not progress}

if parallel_obj.return_generator:
fitresults = tqdm.auto.tqdm(
parallel_obj(
joblib.delayed(_fit)(gold_sel.isel(alpha=i), weights[i])
for i in range(n_fits)
),
**tqdm_kw,
)
else:
if progress:
with joblib_progress(**tqdm_kw) as _:
fitresults = parallel_obj(
joblib.delayed(_fit)(gold_sel.isel(alpha=i), weights[i])
for i in range(n_fits)
)
else:
fitresults = parallel_obj(
joblib.delayed(_fit)(gold_sel.isel(alpha=i), weights[i])
for i in range(n_fits)
)
else:
fitresults = parallel_obj(
joblib.delayed(_fit)(gold_sel.isel(alpha=i), weights[i])
for i in range(n_fits)
)

if return_full:
return fitresults
return list(fitresults)

xval = []
res_vals = []
Expand Down
5 changes: 4 additions & 1 deletion src/erlab/interactive/fermiedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def set_params(self, data, x0, y0, x1, y1, params):
self.y_range: tuple[float, float] = (y0, y1)
self.params = params
self.parallel_obj = joblib.Parallel(
n_jobs=self.params["# CPU"], max_nbytes=None
n_jobs=self.params["# CPU"],
max_nbytes=None,
return_as="list",
pre_dispatch="n_jobs",
)

@QtCore.Slot()
Expand Down
27 changes: 3 additions & 24 deletions src/erlab/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,19 @@

import joblib
import joblib._parallel_backends
import tqdm
import tqdm.notebook
import tqdm.auto
from qtpy import QtCore


def is_notebook():
# http://stackoverflow.com/questions/34091701/determine-if-were-in-an-ipython-notebook-session
if "IPython" not in sys.modules: # IPython hasn't been imported
return False
from IPython import get_ipython

# check for `kernel` attribute on the IPython instance
return getattr(get_ipython(), "kernel", None) is not None


@contextlib.contextmanager
def joblib_progress(file=None, notebook=None, dynamic_ncols=True, **kwargs):
def joblib_progress(file=None, **kwargs):
"""Context manager to patch joblib to report into tqdm progress bar given as
argument"""

if file is None:
file = sys.stdout

if notebook is None:
notebook = is_notebook()

if notebook:
tqdm_object = tqdm.notebook.tqdm(
iterable=None, dynamic_ncols=dynamic_ncols, file=file, **kwargs
)
else:
tqdm_object = tqdm.tqdm(
iterable=None, dynamic_ncols=dynamic_ncols, file=file, **kwargs
)
tqdm_object = tqdm.auto.tqdm(iterable=None, file=file, **kwargs)

def tqdm_print_progress(self):
if self.n_completed_tasks > tqdm_object.n:
Expand Down

0 comments on commit 23d41b3

Please sign in to comment.