Skip to content

Commit

Permalink
refactor: make zip strict (ruff B905)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Apr 25, 2024
1 parent c97724d commit 78bf5f5
Show file tree
Hide file tree
Showing 23 changed files with 101 additions and 51 deletions.
4 changes: 2 additions & 2 deletions docs/source/pyplots/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def sample_plot(norms, labels, kw0, kw1, cmap):
figsize=eplt.figwh(),
)

for norm, label, k0, k1 in zip(norms, labels, kw0, kw1):
for norm, label, k0, k1 in zip(norms, labels, kw0, kw1, strict=True):
axs[0].plot(x, norm(**k0, **k1)(x), label=label)

bar_data = modulatedBarData(384, 256)
for i, (ax, norm, k1) in enumerate(zip(axs[1:], norms, kw1)):
for i, (ax, norm, k1) in enumerate(zip(axs[1:], norms, kw1, strict=True)):
ax.plot(
0.5,
1,
Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ select = [
"RUF",
]
ignore = [
"B905",
"ICN001", # Import conventions
"TRY003", # Long exception messages
]
Expand Down Expand Up @@ -205,7 +204,13 @@ warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true
allow_redefinition = true
exclude = ["^docs/", "^tests/", "_deprecated/", "interactive/fermiedge.py", "io/"]
exclude = [
"^docs/",
"^tests/",
"_deprecated/",
"interactive/fermiedge.py",
"io/",
]

[[tool.mypy.overrides]]
module = [
Expand Down
8 changes: 6 additions & 2 deletions src/erlab/accessors/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def _broadcast_dict_values(d: dict[str, Any]) -> dict[str, xr.DataArray]:
else:
to_broadcast[k] = xr.DataArray(v)

for k, v in zip(to_broadcast.keys(), xr.broadcast(*to_broadcast.values())):
for k, v in zip(
to_broadcast.keys(), xr.broadcast(*to_broadcast.values()), strict=True
):
d[k] = v
return d

Expand Down Expand Up @@ -385,7 +387,9 @@ def _wrapper(Y, *args, **kwargs):
if n_coords == 1:
indep_var_kwargs = {model.independent_vars[0]: x}
else:
indep_var_kwargs = dict(zip(model.independent_vars[:n_coords], x))
indep_var_kwargs = dict(
zip(model.independent_vars[:n_coords], x, strict=True)
)
else:
raise ValueError("Independent variables not defined in model")

Expand Down
6 changes: 5 additions & 1 deletion src/erlab/accessors/kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,11 @@ def _inverse_broadcast(self, kx, ky, kz=None) -> dict[str, xr.DataArray]:
return cast(
dict[str, xr.DataArray],
dict(
zip(cast(list[str], out_dict.keys()), xr.broadcast(*out_dict.values()))
zip(
cast(list[str], out_dict.keys()),
xr.broadcast(*out_dict.values()),
strict=True,
)
),
)

Expand Down
12 changes: 7 additions & 5 deletions src/erlab/analysis/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def acf2(arr, mode: str = "full", method: str = "fft"):
acf,
{
d: autocorrelation_lags(n, mode) * s
for s, n, d in zip(steps, arr.shape, out.dims)
for s, n, d in zip(steps, arr.shape, out.dims, strict=True)
},
attrs=out.attrs,
)
Expand All @@ -114,14 +114,14 @@ def acf2stack(arr, stack_dims=("eV",), mode: str = "full", method: str = "fft"):

out_list = joblib.Parallel(n_jobs=-1, pre_dispatch="3 * n_jobs")(
joblib.delayed(nanacf)(
np.squeeze(arr.isel(dict(zip(stack_dims, vals))).values),
np.squeeze(arr.isel(dict(zip(stack_dims, vals, strict=True))).values),
mode,
method,
)
for vals in itertools.product(*stack_iter)
)
acf_dims = tuple(filter(lambda d: d not in stack_dims, arr.dims))
acf_sizes = dict(zip(acf_dims, out_list[0].shape))
acf_sizes = dict(zip(acf_dims, out_list[0].shape, strict=True))
acf_steps = tuple(arr[d].values[1] - arr[d].values[0] for d in acf_dims)

out_sizes = stack_sizes | acf_sizes
Expand All @@ -137,12 +137,14 @@ def acf2stack(arr, stack_dims=("eV",), mode: str = "full", method: str = "fft"):
out = out.assign_coords({d: arr[d] for d in stack_dims})

for i, vals in enumerate(itertools.product(*stack_iter)):
out.loc[{s: arr[s][v] for s, v in zip(stack_dims, vals)}] = out_list[i]
out.loc[{s: arr[s][v] for s, v in zip(stack_dims, vals, strict=True)}] = (
out_list[i]
)

out = out.assign_coords(
{
d: autocorrelation_lags(len(arr[d]), mode) * s
for s, d in zip(acf_steps, acf_dims)
for s, d in zip(acf_steps, acf_dims, strict=True)
}
)
if all(i in out.dims for i in ["kx", "ky"]):
Expand Down
17 changes: 12 additions & 5 deletions src/erlab/analysis/fit/minuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import xarray
from iminuit.util import _detect_log_spacing, _smart_sampling

import erlab.plotting.general
Expand Down Expand Up @@ -103,14 +104,16 @@ class Minuit(iminuit.Minuit):
def from_lmfit(
cls,
model: lmfit.Model,
data: npt.NDArray,
ivars: npt.NDArray | Sequence[npt.NDArray],
data: npt.NDArray | xarray.DataArray,
ivars: npt.NDArray
| xarray.DataArray
| Sequence[npt.NDArray | xarray.DataArray],
yerr: float | npt.NDArray | None = None,
return_cost: bool = False,
**kwargs,
) -> Minuit | tuple[LeastSq, Minuit]:
if len(model.independent_vars) == 1:
if isinstance(ivars, npt.NDArray):
if isinstance(ivars, np.ndarray | xarray.DataArray):
ivars = [ivars]

x: npt.NDArray | list[npt.NDArray] = [np.asarray(a) for a in ivars]
Expand Down Expand Up @@ -176,12 +179,16 @@ def from_lmfit(
if len(model.independent_vars) == 1:

def _temp_func(x, *fargs):
return model.func(x, **dict(zip(model._param_root_names, fargs)))
return model.func(
x, **dict(zip(model._param_root_names, fargs, strict=True))
)

else:

def _temp_func(x, *fargs):
return model.func(*x, **dict(zip(model._param_root_names, fargs)))
return model.func(
*x, **dict(zip(model._param_root_names, fargs, strict=True))
)

c = LeastSq(x, data, yerr, _temp_func)
m = cls(c, name=param_names, **values)
Expand Down
6 changes: 3 additions & 3 deletions src/erlab/analysis/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def gaussian_filter(
elif np.isscalar(sigma):
sigma_dict = dict.fromkeys(darr.dims, sigma)
elif isinstance(sigma, Collection):
sigma_dict = dict(zip(darr.dims, sigma))
sigma_dict = dict(zip(darr.dims, sigma, strict=True))
else:
raise TypeError("`sigma` must be a scalar, sequence, or mapping")

Expand All @@ -123,7 +123,7 @@ def gaussian_filter(
elif isinstance(radius, Sized):
if len(radius) != len(sigma_dict):
raise ValueError("`radius` does not match dimensions of `sigma`")
radius_dict = dict(zip(sigma_dict.keys(), radius))
radius_dict = dict(zip(sigma_dict.keys(), radius, strict=True))
elif np.isscalar(radius):
radius_dict = dict.fromkeys(sigma_dict.keys(), radius)
else:
Expand Down Expand Up @@ -211,7 +211,7 @@ def gaussian_laplace(
elif np.isscalar(sigma):
sigma_dict = dict.fromkeys(darr.dims, sigma)
elif isinstance(sigma, Collection):
sigma_dict = dict(zip(darr.dims, sigma))
sigma_dict = dict(zip(darr.dims, sigma, strict=True))
else:
raise TypeError("`sigma` must be a scalar, sequence, or mapping")

Expand Down
4 changes: 2 additions & 2 deletions src/erlab/analysis/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ def shift(
for idxs in itertools.product(*[range(darr.shape[i]) for i in domain_indices]):
# Construct slices for indexing
_slices: list[slice | int] = [slice(None)] * darr.ndim
for domain_index, i in zip(domain_indices, idxs):
for domain_index, i in zip(domain_indices, idxs, strict=True):
_slices[domain_index] = i

slices: tuple[slice | int, ...] = tuple(_slices)

# Initialize arguments to `scipy.ndimage.shift`
input = out[slices]
shifts: list[float] = [0.0] * input.ndim
shift_val: float = float(shift.isel(dict(zip(shift.dims, idxs))))
shift_val: float = float(shift.isel(dict(zip(shift.dims, idxs, strict=True))))
shifts[cast(int, input.get_axis_num(along))] = shift_val

# Apply shift
Expand Down
5 changes: 4 additions & 1 deletion src/erlab/interactive/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def copy_code(self):
arg_dict = {
dim: f"|np.linspace(*{data_name}['{dim}'][[0, -1]], {n})|"
for dim, n in zip(
[self.xdim, self.ydim], [self.nx_spin.value(), self.ny_spin.value()]
[self.xdim, self.ydim],
[self.nx_spin.value(), self.ny_spin.value()],
strict=True,
)
}
lines.append(
Expand All @@ -244,6 +246,7 @@ def copy_code(self):
np.round(s.value(), s.decimals())
for s in (self.sx_spin, self.sy_spin)
],
strict=True,
)
)
}
Expand Down
6 changes: 5 additions & 1 deletion src/erlab/interactive/imagetool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def _generate_menu_kwargs(self) -> dict:
),
(1, 1, 0, 0) * 2,
(1, -1, 1, -1, 10, -10, 10, -10),
strict=True,
)
):
menu_kwargs["viewMenu"]["actions"]["cursorMoveMenu"]["actions"][
Expand Down Expand Up @@ -371,6 +372,7 @@ def _generate_menu_kwargs(self) -> dict:
),
(1, 1, 0, 0) * 2,
(1, -1, 1, -1, 10, -10, 10, -10),
strict=True,
)
):
menu_kwargs["viewMenu"]["actions"]["cursorMoveMenu"]["actions"][
Expand Down Expand Up @@ -400,7 +402,9 @@ def refreshMenus(self):
self.action_dict["snapCursorAct"].blockSignals(False)

cmap_props = self.slicer_area.colormap_properties
for ca, k in zip(self.colorAct, ["reversed", "highContrast", "zeroCentered"]):
for ca, k in zip(
self.colorAct, ["reversed", "highContrast", "zeroCentered"], strict=True
):
ca.blockSignals(True)
ca.setChecked(cmap_props[k])
ca.blockSignals(False)
Expand Down
8 changes: 6 additions & 2 deletions src/erlab/interactive/imagetool/_deprecated/imagetool_mpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,9 @@ def update_spans(self):
span.set_xy(get_xy_y(*domain))
span.set_visible(self.visible)
if self.useblit:
for i, span in list(zip(self.span_ax_index[axis], self.spans[axis])):
for i, span in list(
zip(self.span_ax_index[axis], self.spans[axis], strict=True)
):
self.axes[i].draw_artist(span)

def get_index_of_value(self, axis, val):
Expand Down Expand Up @@ -968,7 +970,9 @@ def _update(self):
# self.pool(delayed(self.axes[i].draw_artist)(art) for i, art in list(zip(
# (0, 1, 4, 0, 2, 5, 3, 5, 4), self.cursors)))
else:
for i, art in list(zip(self.ax_index, self.all + self.scaling_axes)):
for i, art in list(
zip(self.ax_index, self.all + self.scaling_axes, strict=True)
):
self.axes[i].draw_artist(art)
if any(self.averaged):
self.update_spans()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ def _initialize_layout(
)
else:
raise NotImplementedError("Only supports 2D, 3D, and 4D arrays.")
for i, (p, sel) in enumerate(zip(self.axes, valid_selection)):
for i, (p, sel) in enumerate(zip(self.axes, valid_selection, strict=True)):
p.setDefaultPadding(0)
for axis in ["left", "bottom", "right", "top"]:
p.getAxis(axis).setTickFont(font)
Expand Down
19 changes: 13 additions & 6 deletions src/erlab/interactive/imagetool/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ def adjust_layout(
]
if self.data.ndim == 4:
sizes[3] = (0, 0, (r0 + r1 - d))
for split, sz in zip(self._splitters, sizes):
for split, sz in zip(self._splitters, sizes, strict=True):
split.setSizes(tuple(round(s * scale) for s in sz))

for i, sel in enumerate(valid_axis):
Expand Down Expand Up @@ -1226,7 +1226,10 @@ def refresh_manual_range(self):
if self.is_independent:
return
for dim, auto, rng in zip(
self.axis_dims, self.vb.state["autoRange"], self.vb.state["viewRange"]
self.axis_dims,
self.vb.state["autoRange"],
self.vb.state["viewRange"],
strict=True,
):
if dim is not None:
if auto:
Expand All @@ -1240,7 +1243,7 @@ def update_manual_range(self):
self.set_range_from(self.slicer_area.manual_limits)

def set_range_from(self, limits: dict[str, list[float]], **kwargs):
for dim, key in zip(self.axis_dims, ("xRange", "yRange")):
for dim, key in zip(self.axis_dims, ("xRange", "yRange"), strict=True):
if dim is not None:
try:
kwargs[key] = limits[dim]
Expand Down Expand Up @@ -1374,7 +1377,7 @@ def add_cursor(self, update=True):

self.cursor_lines.append({})
self.cursor_spans.append({})
for c, s, ax in zip(cursors, spans, self.display_axis):
for c, s, ax in zip(cursors, spans, self.display_axis, strict=True):
self.cursor_lines[-1][ax] = c
self.cursor_spans[-1][ax] = s
self.addItem(c)
Expand Down Expand Up @@ -1422,7 +1425,9 @@ def remove_cursor(self, index: int):
item = self.slicer_data_items.pop(index)
self.removeItem(item)
for line, span in zip(
self.cursor_lines.pop(index).values(), self.cursor_spans.pop(index).values()
self.cursor_lines.pop(index).values(),
self.cursor_spans.pop(index).values(),
strict=True,
):
self.removeItem(line)
self.removeItem(span)
Expand Down Expand Up @@ -1471,7 +1476,9 @@ def refresh_labels(self):
if self.is_image:
label_kw = {
a: self._get_label_unit(i)
for a, i in zip(("top", "bottom", "left", "right"), (0, 0, 1, 1))
for a, i in zip(
("top", "bottom", "left", "right"), (0, 0, 1, 1), strict=True
)
if self.getAxis(a).isVisible()
}
else:
Expand Down
6 changes: 3 additions & 3 deletions src/erlab/interactive/imagetool/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def set_array(
[s // 2 - (1 if s % 2 == 0 else 0) for s in self._obj.shape]
]
self._values: list[list[np.float32]] = [
[c[i] for c, i in zip(self.coords, self._indices[0])]
[c[i] for c, i in zip(self.coords, self._indices[0], strict=True)]
]
self.snap_to_data: bool = False

Expand Down Expand Up @@ -361,7 +361,7 @@ def add_cursor(self, like_cursor: int = -1, update: bool = True) -> None:
self._bins.append(list(self.get_bins(like_cursor)))
new_ind = self.get_indices(like_cursor)
self._indices.append(list(new_ind))
self._values.append([c[i] for c, i in zip(self.coords, new_ind)])
self._values.append([c[i] for c, i in zip(self.coords, new_ind, strict=True)])
if update:
self.sigCursorCountChanged.emit(self.n_cursors)

Expand Down Expand Up @@ -644,7 +644,7 @@ def xslice(self, cursor: int, disp: Sequence[int]) -> xr.DataArray:
isel_kw = self.isel_args(cursor, disp, int_if_one=False)
binned_coord_average: dict[str, xr.DataArray] = {
str(k): self._obj[k][isel_kw[str(k)]].mean()
for k, v in zip(self._obj.dims, self.get_binned(cursor))
for k, v in zip(self._obj.dims, self.get_binned(cursor), strict=True)
if v
}
return (
Expand Down
7 changes: 5 additions & 2 deletions src/erlab/interactive/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def labelString(self):
for k, v in zip(
("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-"),
("⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹", "⁻"),
strict=True,
):
units = units.replace(k, v)
units = f"10{units}"
Expand Down Expand Up @@ -1156,15 +1157,17 @@ def update_pos(self):
self.widgets["y0"].setMaximum(self.widgets["y1"].value())
self.widgets["x1"].setMinimum(self.widgets["x0"].value())
self.widgets["y1"].setMinimum(self.widgets["y0"].value())
for pos, spin in zip(self.roi_limits, self.roi_spin):
for pos, spin in zip(self.roi_limits, self.roi_spin, strict=True):
spin.blockSignals(True)
spin.setValue(pos)
spin.blockSignals(False)

def modify_roi(self, x0=None, y0=None, x1=None, y1=None, update=True):
lim_new = (x0, y0, x1, y1)
lim_old = self.roi_limits
x0, y0, x1, y1 = ((f if f is not None else i) for i, f in zip(lim_old, lim_new))
x0, y0, x1, y1 = (
(f if f is not None else i) for i, f in zip(lim_old, lim_new, strict=True)
)
xm, ym, xM, yM = self.roi.maxBounds.getCoords()
x0, y0, x1, y1 = max(x0, xm), max(y0, ym), min(x1, xM), min(y1, yM)
self.roi.setPos((x0, y0), update=False)
Expand Down
Loading

0 comments on commit 78bf5f5

Please sign in to comment.