Skip to content

Commit

Permalink
style: enforce ruff SIM
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Jul 9, 2024
1 parent 99dd897 commit 1ce4d36
Show file tree
Hide file tree
Showing 30 changed files with 194 additions and 312 deletions.
5 changes: 1 addition & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ def linkcode_resolve(domain, info) -> str | None:
except OSError:
lineno = None

if lineno:
linespec = f"#L{lineno}-L{lineno + len(source) - 1}"
else:
linespec = ""
linespec = f"#L{lineno}-L{lineno + len(source) - 1}" if lineno else ""

import erlab

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ select = [
"Q",
"RSE",
"RET",
"SIM",
"TID",
"TCH",
"INT",
Expand Down
31 changes: 10 additions & 21 deletions src/erlab/accessors/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"ParallelFitDataArrayAccessor",
]

import contextlib
import copy
import itertools
import warnings
Expand Down Expand Up @@ -73,7 +74,7 @@ def _parse_params(


def _parse_multiple_params(d: dict[str, Any], as_str: bool) -> xr.DataArray:
for k in d.keys():
for k in d:
if isinstance(d[k], int | float | complex | xr.DataArray):
d[k] = {"value": d[k]}

Expand Down Expand Up @@ -269,11 +270,7 @@ def __call__(
else:
reduce_dims_ = list(reduce_dims)

if (
isinstance(coords, str)
or isinstance(coords, xr.DataArray)
or not isinstance(coords, Iterable)
):
if isinstance(coords, str | xr.DataArray) or not isinstance(coords, Iterable):
coords = [coords]
coords_: Sequence[xr.DataArray] = [
self._obj[coord] if isinstance(coord, str) else coord for coord in coords
Expand Down Expand Up @@ -339,10 +336,7 @@ def _wrapper(Y, *args, **kwargs):
coords__ = args[:n_coords]
init_params_ = args[n_coords]

if guess:
initial_params = lmfit.create_params()
else:
initial_params = model.make_params()
initial_params = lmfit.create_params() if guess else model.make_params()

if isinstance(init_params_, _ParametersWraper):
initial_params.update(init_params_.params)
Expand Down Expand Up @@ -400,10 +394,8 @@ def _wrapper(Y, *args, **kwargs):
if isinstance(model, lmfit.model.CompositeModel):
guessed_params = model.make_params()
for comp in model.components:
try:
with contextlib.suppress(NotImplementedError):
guessed_params.update(comp.guess(y, **indep_var_kwargs))
except NotImplementedError:
pass
# Given parameters must override guessed parameters
initial_params = guessed_params.update(initial_params)

Expand Down Expand Up @@ -461,24 +453,21 @@ def _wrapper(Y, *args, **kwargs):
return popt, perr, pcov, stats, data, best, modres

def _output_wrapper(name, da, out=None) -> dict:
if name is _THIS_ARRAY:
name = ""
else:
name = f"{name!s}_"
name = "" if name is _THIS_ARRAY else f"{name!s}_"

if out is None:
out = {}

input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)]
input_core_dims.extend([[] for _ in range(1)]) # core_dims for parameters

if isinstance(params, xr.Dataset):
if not isinstance(params, xr.Dataset):
params_to_apply = params
else:
try:
params_to_apply = params[name.rstrip("_")]
except KeyError:
params_to_apply = params[float(name.rstrip("_"))]
else:
params_to_apply = params

popt, perr, pcov, stats, data, best, modres = xr.apply_ufunc(
_wrapper,
Expand Down Expand Up @@ -668,7 +657,7 @@ def __call__(self, dim: str, model: lmfit.Model, **kwargs) -> xr.Dataset:

drop_keys = []
concat_vars: dict[Hashable, list[xr.DataArray]] = {}
for k in ds.data_vars.keys():
for k in ds.data_vars:
for var in self._VAR_KEYS:
key = f"{k}_{var}"
if key in fitres:
Expand Down
2 changes: 1 addition & 1 deletion src/erlab/accessors/kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def convert(
target_dict: dict[str, xr.DataArray] = self._inverse_broadcast(
momentum_coords.get("kx"),
momentum_coords.get("ky"),
momentum_coords.get("kz", None),
momentum_coords.get("kz"),
)

# Coords of first value in target_dict. Output of inverse_broadcast are all
Expand Down
5 changes: 1 addition & 4 deletions src/erlab/analysis/fit/functions/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def _infer_meshgrid_shape(arr: np.ndarray) -> tuple[tuple[int, int], int, np.nda
# The shape of the original meshgrid
shape = len(arr) // (change_index[0] + 1), change_index[0] + 1

if axis == 0:
coord = arr.reshape(shape)[:, 0]
else:
coord = arr.reshape(shape)[0, :]
coord = arr.reshape(shape)[:, 0] if axis == 0 else arr.reshape(shape)[0, :]

return shape, axis, coord

Expand Down
7 changes: 4 additions & 3 deletions src/erlab/analysis/fit/minuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ def from_lmfit(
return_cost: bool = False,
**kwargs,
) -> Minuit | tuple[LeastSq, Minuit]:
if len(model.independent_vars) == 1:
if isinstance(ivars, np.ndarray | xarray.DataArray):
ivars = [ivars]
if len(model.independent_vars) == 1 and isinstance(
ivars, np.ndarray | xarray.DataArray
):
ivars = [ivars]

x: npt.NDArray | list[npt.NDArray] = [np.asarray(a) for a in ivars]

Expand Down
5 changes: 2 additions & 3 deletions src/erlab/analysis/fit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"StepEdgeModel",
]

import contextlib
from typing import Literal

import lmfit
Expand Down Expand Up @@ -179,10 +180,8 @@ def guess(self, data, x, **kwargs):

temp = 30.0
if isinstance(data, xr.DataArray):
try:
with contextlib.suppress(KeyError):
temp = float(data.attrs["temp_sample"])
except KeyError:
pass

pars[f"{self.prefix}center"].set(
value=efermi, min=np.asarray(x).min(), max=np.asarray(x).max()
Expand Down
5 changes: 1 addition & 4 deletions src/erlab/analysis/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,7 @@ def quick_fit(
"""
data = darr.mean([d for d in darr.dims if d != "eV"])

if eV_range is not None:
data_fit = data.sel(eV=slice(*eV_range))
else:
data_fit = data
data_fit = data.sel(eV=slice(*eV_range)) if eV_range is not None else data

if temp is None:
if "temp_sample" in data.attrs:
Expand Down
19 changes: 8 additions & 11 deletions src/erlab/analysis/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def _parse_dict_arg(
f"{'' if len(required_dims) == 1 else 's'}: {required_dims}"
)

for d in sigma_dict.keys():
for d in sigma_dict:
if d not in dims:
raise ValueError(
f"Dimension `{d}` in {arg_name} not found in {reference_name}"
)

# Make sure that sigma_dict is ordered in temrs of data dims
return {d: sigma_dict[d] for d in dims if d in sigma_dict.keys()}
return {d: sigma_dict[d] for d in dims if d in sigma_dict}


def gaussian_filter(
Expand Down Expand Up @@ -164,14 +164,14 @@ def gaussian_filter(
)

# Get the axis indices to apply the filter
axes = tuple(darr.get_axis_num(d) for d in sigma_dict.keys())
axes = tuple(darr.get_axis_num(d) for d in sigma_dict)

# Convert arguments to tuples acceptable by scipy
if isinstance(order, Mapping):
order = tuple(order.get(str(d), 0) for d in sigma_dict.keys())
order = tuple(order.get(str(d), 0) for d in sigma_dict)

if isinstance(mode, Mapping):
mode = tuple(mode[str(d)] for d in sigma_dict.keys())
mode = tuple(mode[str(d)] for d in sigma_dict)

if radius is not None:
radius_dict = _parse_dict_arg(
Expand All @@ -186,7 +186,7 @@ def gaussian_filter(
else:
radius_pix = None

for d in sigma_dict.keys():
for d in sigma_dict:
if not is_uniform_spaced(darr[d].values):
raise ValueError(f"Dimension `{d}` is not uniformly spaced")

Expand Down Expand Up @@ -268,7 +268,7 @@ def gaussian_laplace(

# Convert mode to tuple acceptable by scipy
if isinstance(mode, dict):
mode = tuple(mode[d] for d in sigma_dict.keys())
mode = tuple(mode[d] for d in sigma_dict)

# Calculate sigma in pixels
sigma_pix: tuple[float, ...] = tuple(
Expand Down Expand Up @@ -433,10 +433,7 @@ def ndsavgol(
if method not in ["pinv", "lstsq"]:
raise ValueError("method must be 'pinv' or 'lstsq'")

if method == "lstsq":
accurate = True
else:
accurate = False
accurate = method == "lstsq"

if isinstance(window_shape, int):
window_shape = (window_shape,) * arr.ndim
Expand Down
15 changes: 4 additions & 11 deletions src/erlab/analysis/mask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ def spherical_mask(
array([False, True, True, True, False])
Dimensions without coordinates: x
"""
if isinstance(radius, dict):
if set(radius.keys()) != set(sel_kw.keys()):
raise ValueError("Keys in radius and sel_kw must match")
if isinstance(radius, dict) and set(radius.keys()) != set(sel_kw.keys()):
raise ValueError("Keys in radius and sel_kw must match")

if len(sel_kw) == 0:
raise ValueError("No dimensions provided for mask")
Expand All @@ -232,10 +231,7 @@ def spherical_mask(
if k not in darr.dims:
raise ValueError(f"Dimension {k} not found in data")

if isinstance(radius, dict):
r = radius[k]
else:
r = float(radius)
r = radius[k] if isinstance(radius, dict) else float(radius)

delta_squared = delta_squared + ((darr[k] - v) / r) ** 2

Expand Down Expand Up @@ -287,10 +283,7 @@ def hex_bz_mask_points(
invert: bool = False,
) -> npt.NDArray[np.bool_]:
"""Return a mask for given points."""
if reciprocal:
d = 2 * np.pi / (a * 3)
else:
d = a
d = 2 * np.pi / (a * 3) if reciprocal else a
ang = rotate + np.array([0, 60, 120, 180, 240, 300])
vertices = np.array(
[
Expand Down
10 changes: 3 additions & 7 deletions src/erlab/interactive/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"pg_colormap_to_QPixmap",
]

import contextlib
import weakref
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Literal
Expand Down Expand Up @@ -74,10 +75,8 @@ def __init__(self, *args, **kwargs) -> None:
def load_thumbnail(self, index: int) -> None:
if not self.thumbnails_loaded:
text = self.itemText(index)
try:
with contextlib.suppress(KeyError):
self.setItemIcon(index, QtGui.QIcon(pg_colormap_to_QPixmap(text)))
except KeyError:
pass

def load_all(self) -> None:
self.clear()
Expand Down Expand Up @@ -628,10 +627,7 @@ def pg_colormap_names(
# if (_mpl != []) and (cet != []):
# local = []

if exclude_local:
all_cmaps = cet + _mpl
else:
all_cmaps = local + cet + _mpl
all_cmaps = cet + _mpl if exclude_local else local + cet + _mpl
elif exclude_local:
all_cmaps = _mpl
else:
Expand Down
15 changes: 4 additions & 11 deletions src/erlab/interactive/fermiedge.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,7 @@ def _perform_poly_fit(self):
method=params["Method"],
scale_covar=params["Scale cov"],
)
if self.data_corr is None:
target = self.data
else:
target = self.data_corr
target = self.data if self.data_corr is None else self.data_corr
self.corrected = erlab.analysis.correct_with_edge(
target, self.result, plot=False, shift_coords=params["Shift coords"]
)
Expand All @@ -465,10 +462,7 @@ def _perform_spline_fit(self):
lam=params["lambda"],
)

if self.data_corr is None:
target = self.data
else:
target = self.data_corr
target = self.data if self.data_corr is None else self.data_corr
self.corrected = erlab.analysis.correct_with_edge(
target, self.result, plot=False, shift_coords=params["Shift coords"]
)
Expand Down Expand Up @@ -515,9 +509,8 @@ def gen_code(self, mode: str) -> None:
if not p0["Scale cov"]:
arg_dict["scale_covar_edge"] = False

if mode == "poly":
if not p1["Scale cov"]:
arg_dict["scale_covar"] = False
if mode == "poly" and not p1["Scale cov"]:
arg_dict["scale_covar"] = False

if self.data_corr is None:
gen_function_code(
Expand Down
2 changes: 1 addition & 1 deletion src/erlab/interactive/imagetool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def _open_file(
"xarray HDF5 Files (*.h5)": (erlab.io.load_hdf5, {}),
"NetCDF Files (*.nc *.nc4 *.cdf)": (xr.load_dataarray, {}),
}
for k in erlab.io.loaders.keys():
for k in erlab.io.loaders:
valid_loaders = valid_loaders | erlab.io.loaders[k].file_dialog_methods

dialog = QtWidgets.QFileDialog(self)
Expand Down
Loading

0 comments on commit 1ce4d36

Please sign in to comment.