diff --git a/docs/source/conf.py b/docs/source/conf.py index eaaea52b..9a040f12 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -101,10 +101,7 @@ def linkcode_resolve(domain, info) -> str | None: except OSError: lineno = None - if lineno: - linespec = f"#L{lineno}-L{lineno + len(source) - 1}" - else: - linespec = "" + linespec = f"#L{lineno}-L{lineno + len(source) - 1}" if lineno else "" import erlab diff --git a/pyproject.toml b/pyproject.toml index 39b9aa43..bd6083e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,6 +163,7 @@ select = [ "Q", "RSE", "RET", + "SIM", "TID", "TCH", "INT", diff --git a/src/erlab/accessors/fit.py b/src/erlab/accessors/fit.py index 8f590607..73d821f2 100644 --- a/src/erlab/accessors/fit.py +++ b/src/erlab/accessors/fit.py @@ -6,6 +6,7 @@ "ParallelFitDataArrayAccessor", ] +import contextlib import copy import itertools import warnings @@ -73,7 +74,7 @@ def _parse_params( def _parse_multiple_params(d: dict[str, Any], as_str: bool) -> xr.DataArray: - for k in d.keys(): + for k in d: if isinstance(d[k], int | float | complex | xr.DataArray): d[k] = {"value": d[k]} @@ -269,11 +270,7 @@ def __call__( else: reduce_dims_ = list(reduce_dims) - if ( - isinstance(coords, str) - or isinstance(coords, xr.DataArray) - or not isinstance(coords, Iterable) - ): + if isinstance(coords, str | xr.DataArray) or not isinstance(coords, Iterable): coords = [coords] coords_: Sequence[xr.DataArray] = [ self._obj[coord] if isinstance(coord, str) else coord for coord in coords @@ -339,10 +336,7 @@ def _wrapper(Y, *args, **kwargs): coords__ = args[:n_coords] init_params_ = args[n_coords] - if guess: - initial_params = lmfit.create_params() - else: - initial_params = model.make_params() + initial_params = lmfit.create_params() if guess else model.make_params() if isinstance(init_params_, _ParametersWraper): initial_params.update(init_params_.params) @@ -400,10 +394,8 @@ def _wrapper(Y, *args, **kwargs): if isinstance(model, lmfit.model.CompositeModel): guessed_params = model.make_params() for comp in model.components: - try: + with contextlib.suppress(NotImplementedError): guessed_params.update(comp.guess(y, **indep_var_kwargs)) - except NotImplementedError: - pass # Given parameters must override guessed parameters initial_params = guessed_params.update(initial_params) @@ -461,10 +453,7 @@ def _wrapper(Y, *args, **kwargs): return popt, perr, pcov, stats, data, best, modres def _output_wrapper(name, da, out=None) -> dict: - if name is _THIS_ARRAY: - name = "" - else: - name = f"{name!s}_" + name = "" if name is _THIS_ARRAY else f"{name!s}_" if out is None: out = {} @@ -472,13 +461,13 @@ def _output_wrapper(name, da, out=None) -> dict: input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)] input_core_dims.extend([[] for _ in range(1)]) # core_dims for parameters - if isinstance(params, xr.Dataset): + if not isinstance(params, xr.Dataset): + params_to_apply = params + else: try: params_to_apply = params[name.rstrip("_")] except KeyError: params_to_apply = params[float(name.rstrip("_"))] - else: - params_to_apply = params popt, perr, pcov, stats, data, best, modres = xr.apply_ufunc( _wrapper, @@ -668,7 +657,7 @@ def __call__(self, dim: str, model: lmfit.Model, **kwargs) -> xr.Dataset: drop_keys = [] concat_vars: dict[Hashable, list[xr.DataArray]] = {} - for k in ds.data_vars.keys(): + for k in ds.data_vars: for var in self._VAR_KEYS: key = f"{k}_{var}" if key in fitres: diff --git a/src/erlab/accessors/kspace.py b/src/erlab/accessors/kspace.py index 5f19feac..68fa7c44 100644 --- a/src/erlab/accessors/kspace.py +++ b/src/erlab/accessors/kspace.py @@ -786,7 +786,7 @@ def convert( target_dict: dict[str, xr.DataArray] = self._inverse_broadcast( momentum_coords.get("kx"), momentum_coords.get("ky"), - momentum_coords.get("kz", None), + momentum_coords.get("kz"), ) # Coords of first value in target_dict. Output of inverse_broadcast are all diff --git a/src/erlab/analysis/fit/functions/general.py b/src/erlab/analysis/fit/functions/general.py index 4afba3c7..d873ded5 100644 --- a/src/erlab/analysis/fit/functions/general.py +++ b/src/erlab/analysis/fit/functions/general.py @@ -53,10 +53,7 @@ def _infer_meshgrid_shape(arr: np.ndarray) -> tuple[tuple[int, int], int, np.nda # The shape of the original meshgrid shape = len(arr) // (change_index[0] + 1), change_index[0] + 1 - if axis == 0: - coord = arr.reshape(shape)[:, 0] - else: - coord = arr.reshape(shape)[0, :] + coord = arr.reshape(shape)[:, 0] if axis == 0 else arr.reshape(shape)[0, :] return shape, axis, coord diff --git a/src/erlab/analysis/fit/minuit.py b/src/erlab/analysis/fit/minuit.py index 5676cea4..c10d3586 100644 --- a/src/erlab/analysis/fit/minuit.py +++ b/src/erlab/analysis/fit/minuit.py @@ -114,9 +114,10 @@ def from_lmfit( return_cost: bool = False, **kwargs, ) -> Minuit | tuple[LeastSq, Minuit]: - if len(model.independent_vars) == 1: - if isinstance(ivars, np.ndarray | xarray.DataArray): - ivars = [ivars] + if len(model.independent_vars) == 1 and isinstance( + ivars, np.ndarray | xarray.DataArray + ): + ivars = [ivars] x: npt.NDArray | list[npt.NDArray] = [np.asarray(a) for a in ivars] diff --git a/src/erlab/analysis/fit/models.py b/src/erlab/analysis/fit/models.py index 404f294c..e7c68b57 100644 --- a/src/erlab/analysis/fit/models.py +++ b/src/erlab/analysis/fit/models.py @@ -10,6 +10,7 @@ "StepEdgeModel", ] +import contextlib from typing import Literal import lmfit @@ -179,10 +180,8 @@ def guess(self, data, x, **kwargs): temp = 30.0 if isinstance(data, xr.DataArray): - try: + with contextlib.suppress(KeyError): temp = float(data.attrs["temp_sample"]) - except KeyError: - pass pars[f"{self.prefix}center"].set( value=efermi, min=np.asarray(x).min(), max=np.asarray(x).max() diff --git a/src/erlab/analysis/gold.py b/src/erlab/analysis/gold.py index 6516b5fb..4075dcd9 100644 --- a/src/erlab/analysis/gold.py +++ b/src/erlab/analysis/gold.py @@ -589,10 +589,7 @@ def quick_fit( """ data = darr.mean([d for d in darr.dims if d != "eV"]) - if eV_range is not None: - data_fit = data.sel(eV=slice(*eV_range)) - else: - data_fit = data + data_fit = data.sel(eV=slice(*eV_range)) if eV_range is not None else data if temp is None: if "temp_sample" in data.attrs: diff --git a/src/erlab/analysis/image.py b/src/erlab/analysis/image.py index 27978467..b13d1abe 100644 --- a/src/erlab/analysis/image.py +++ b/src/erlab/analysis/image.py @@ -63,14 +63,14 @@ def _parse_dict_arg( f"{'' if len(required_dims) == 1 else 's'}: {required_dims}" ) - for d in sigma_dict.keys(): + for d in sigma_dict: if d not in dims: raise ValueError( f"Dimension `{d}` in {arg_name} not found in {reference_name}" ) # Make sure that sigma_dict is ordered in temrs of data dims - return {d: sigma_dict[d] for d in dims if d in sigma_dict.keys()} + return {d: sigma_dict[d] for d in dims if d in sigma_dict} def gaussian_filter( @@ -164,14 +164,14 @@ def gaussian_filter( ) # Get the axis indices to apply the filter - axes = tuple(darr.get_axis_num(d) for d in sigma_dict.keys()) + axes = tuple(darr.get_axis_num(d) for d in sigma_dict) # Convert arguments to tuples acceptable by scipy if isinstance(order, Mapping): - order = tuple(order.get(str(d), 0) for d in sigma_dict.keys()) + order = tuple(order.get(str(d), 0) for d in sigma_dict) if isinstance(mode, Mapping): - mode = tuple(mode[str(d)] for d in sigma_dict.keys()) + mode = tuple(mode[str(d)] for d in sigma_dict) if radius is not None: radius_dict = _parse_dict_arg( @@ -186,7 +186,7 @@ def gaussian_filter( else: radius_pix = None - for d in sigma_dict.keys(): + for d in sigma_dict: if not is_uniform_spaced(darr[d].values): raise ValueError(f"Dimension `{d}` is not uniformly spaced") @@ -268,7 +268,7 @@ def gaussian_laplace( # Convert mode to tuple acceptable by scipy if isinstance(mode, dict): - mode = tuple(mode[d] for d in sigma_dict.keys()) + mode = tuple(mode[d] for d in sigma_dict) # Calculate sigma in pixels sigma_pix: tuple[float, ...] = tuple( @@ -433,10 +433,7 @@ def ndsavgol( if method not in ["pinv", "lstsq"]: raise ValueError("method must be 'pinv' or 'lstsq'") - if method == "lstsq": - accurate = True - else: - accurate = False + accurate = method == "lstsq" if isinstance(window_shape, int): window_shape = (window_shape,) * arr.ndim diff --git a/src/erlab/analysis/mask/__init__.py b/src/erlab/analysis/mask/__init__.py index 981164a7..9e570bfc 100644 --- a/src/erlab/analysis/mask/__init__.py +++ b/src/erlab/analysis/mask/__init__.py @@ -219,9 +219,8 @@ def spherical_mask( array([False, True, True, True, False]) Dimensions without coordinates: x """ - if isinstance(radius, dict): - if set(radius.keys()) != set(sel_kw.keys()): - raise ValueError("Keys in radius and sel_kw must match") + if isinstance(radius, dict) and set(radius.keys()) != set(sel_kw.keys()): + raise ValueError("Keys in radius and sel_kw must match") if len(sel_kw) == 0: raise ValueError("No dimensions provided for mask") @@ -232,10 +231,7 @@ def spherical_mask( if k not in darr.dims: raise ValueError(f"Dimension {k} not found in data") - if isinstance(radius, dict): - r = radius[k] - else: - r = float(radius) + r = radius[k] if isinstance(radius, dict) else float(radius) delta_squared = delta_squared + ((darr[k] - v) / r) ** 2 @@ -287,10 +283,7 @@ def hex_bz_mask_points( invert: bool = False, ) -> npt.NDArray[np.bool_]: """Return a mask for given points.""" - if reciprocal: - d = 2 * np.pi / (a * 3) - else: - d = a + d = 2 * np.pi / (a * 3) if reciprocal else a ang = rotate + np.array([0, 60, 120, 180, 240, 300]) vertices = np.array( [ diff --git a/src/erlab/interactive/colors.py b/src/erlab/interactive/colors.py index 7c52773d..3cab54ad 100644 --- a/src/erlab/interactive/colors.py +++ b/src/erlab/interactive/colors.py @@ -14,6 +14,7 @@ "pg_colormap_to_QPixmap", ] +import contextlib import weakref from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Literal @@ -74,10 +75,8 @@ def __init__(self, *args, **kwargs) -> None: def load_thumbnail(self, index: int) -> None: if not self.thumbnails_loaded: text = self.itemText(index) - try: + with contextlib.suppress(KeyError): self.setItemIcon(index, QtGui.QIcon(pg_colormap_to_QPixmap(text))) - except KeyError: - pass def load_all(self) -> None: self.clear() @@ -628,10 +627,7 @@ def pg_colormap_names( # if (_mpl != []) and (cet != []): # local = [] - if exclude_local: - all_cmaps = cet + _mpl - else: - all_cmaps = local + cet + _mpl + all_cmaps = cet + _mpl if exclude_local else local + cet + _mpl elif exclude_local: all_cmaps = _mpl else: diff --git a/src/erlab/interactive/fermiedge.py b/src/erlab/interactive/fermiedge.py index ad647fda..5f8b5285 100644 --- a/src/erlab/interactive/fermiedge.py +++ b/src/erlab/interactive/fermiedge.py @@ -446,10 +446,7 @@ def _perform_poly_fit(self): method=params["Method"], scale_covar=params["Scale cov"], ) - if self.data_corr is None: - target = self.data - else: - target = self.data_corr + target = self.data if self.data_corr is None else self.data_corr self.corrected = erlab.analysis.correct_with_edge( target, self.result, plot=False, shift_coords=params["Shift coords"] ) @@ -465,10 +462,7 @@ def _perform_spline_fit(self): lam=params["lambda"], ) - if self.data_corr is None: - target = self.data - else: - target = self.data_corr + target = self.data if self.data_corr is None else self.data_corr self.corrected = erlab.analysis.correct_with_edge( target, self.result, plot=False, shift_coords=params["Shift coords"] ) @@ -515,9 +509,8 @@ def gen_code(self, mode: str) -> None: if not p0["Scale cov"]: arg_dict["scale_covar_edge"] = False - if mode == "poly": - if not p1["Scale cov"]: - arg_dict["scale_covar"] = False + if mode == "poly" and not p1["Scale cov"]: + arg_dict["scale_covar"] = False if self.data_corr is None: gen_function_code( diff --git a/src/erlab/interactive/imagetool/__init__.py b/src/erlab/interactive/imagetool/__init__.py index e0696074..f23cbd9b 100644 --- a/src/erlab/interactive/imagetool/__init__.py +++ b/src/erlab/interactive/imagetool/__init__.py @@ -548,7 +548,7 @@ def _open_file( "xarray HDF5 Files (*.h5)": (erlab.io.load_hdf5, {}), "NetCDF Files (*.nc *.nc4 *.cdf)": (xr.load_dataarray, {}), } - for k in erlab.io.loaders.keys(): + for k in erlab.io.loaders: valid_loaders = valid_loaders | erlab.io.loaders[k].file_dialog_methods dialog = QtWidgets.QFileDialog(self) diff --git a/src/erlab/interactive/imagetool/_deprecated/imagetool_mpl.py b/src/erlab/interactive/imagetool/_deprecated/imagetool_mpl.py index 4793ed44..1bea9f47 100644 --- a/src/erlab/interactive/imagetool/_deprecated/imagetool_mpl.py +++ b/src/erlab/interactive/imagetool/_deprecated/imagetool_mpl.py @@ -33,7 +33,7 @@ def qt_style_names(): """Return a list of styles, default platform style first.""" default_style_name = QtWidgets.QApplication.style().objectName().lower() result = [] - for style in QtWidgets.QStyleFactory.keys(): + for style in QtWidgets.QStyleFactory.keys(): # noqa: SIM118 if style.lower() == default_style_name: result.insert(0, style) else: @@ -852,9 +852,8 @@ def get_index_of_value(self, axis, val): def onmove(self, event) -> None: if self.ignore(event): return - if not event.button: - if not self._shift: - return + if not event.button and not self._shift: + return if event.inaxes not in self.axes: return if not self.canvas.widgetlock.available(self): @@ -1365,9 +1364,8 @@ def zoom_new(self, *args) -> None: def onmove_super(self, event) -> None: if event.inaxes not in self.axes: return - if not event.button: - if not self.itool._shift: - return + if not event.button and not self.itool._shift: + return for i in range(self.ndim): self._cursor_spin[i].blockSignals(True) self._cursor_spin[i].setValue(self.itool._last_ind[i]) @@ -1430,18 +1428,20 @@ def itoolmpl(data, *args, **kwargs) -> None: qapp = QtWidgets.QApplication(sys.argv) # print(qapp.devicePixelRatio()) mpl_style = "default" - with plt.rc_context( - { - "text.usetex": False, - # 'font.family':'SF Pro', - # 'font.size':8, - # 'font.stretch':'condensed', - # 'mathtext.fontset':'cm', - # 'font.family':'fantasy', - } + with ( + plt.rc_context( + { + "text.usetex": False, + # 'font.family':'SF Pro', + # 'font.size':8, + # 'font.stretch':'condensed', + # 'mathtext.fontset':'cm', + # 'font.family':'fantasy', + } + ), + plt.style.context(mpl_style), ): - with plt.style.context(mpl_style): - app = ImageTool(data, *args, **kwargs) + app = ImageTool(data, *args, **kwargs) change_style("Fusion") app.show() app.activateWindow() diff --git a/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py b/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py index fd1c665a..0a828b7b 100644 --- a/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py +++ b/src/erlab/interactive/imagetool/_deprecated/imagetool_old.py @@ -7,6 +7,7 @@ """ import colorsys +import contextlib import enum import importlib import sys @@ -91,10 +92,9 @@ def __init__( self.toggled.connect(self.refresh_icons) def refresh_icons(self) -> None: - if self.icon_key_off is not None: - if self.isChecked(): - self.setIcon(qta.icon(ICON_NAME[self.icon_key_off])) - return + if self.icon_key_off is not None and self.isChecked(): + self.setIcon(qta.icon(ICON_NAME[self.icon_key_off])) + return self.setIcon(qta.icon(ICON_NAME[self.icon_key_on])) def changeEvent(self, evt) -> None: @@ -253,7 +253,7 @@ def qt_style_names(): """Return a list of styles, default platform style first.""" default_style_name = QtWidgets.QApplication.style().objectName().lower() result = [] - for style in QtWidgets.QStyleFactory.keys(): + for style in QtWidgets.QStyleFactory.keys(): # noqa: SIM118 if style.lower() == default_style_name: result.insert(0, style) else: @@ -722,7 +722,7 @@ def mouseDragEvent(self, ev) -> None: def setLabels( self, mode: ItoolAxisItem.LabelType = ItoolAxisItem.LabelType.TextLabel, **kwds ) -> None: - for k in kwds.keys(): + for k in kwds: if k != "title": self.getAxis(k).set_label_mode(mode) super().setLabels(**kwds) @@ -751,13 +751,14 @@ def setAxisItems(self, axisItems=None) -> None: if k in axisItems: axis = axisItems[k] - if axis.scene() is not None: - if k not in self.axes or axis != self.axes[k]["item"]: - raise RuntimeError( - "Can't add an axis to multiple plots. Shared axes" - " can be achieved with multiple AxisItem instances" - " and set[X/Y]Link." - ) + if axis.scene() is not None and ( + k not in self.axes or axis != self.axes[k]["item"] + ): + raise RuntimeError( + "Can't add an axis to multiple plots. Shared axes" + " can be achieved with multiple AxisItem instances" + " and set[X/Y]Link." + ) else: axis = ItoolAxisItem(orientation=k, parent=self) @@ -1022,10 +1023,7 @@ def _update_stretch(self, row=None, col=None) -> None: row_factor = (100000, 150000, 300000) else: row_factor = row - if col is None: - col_factor = row_factor - else: - col_factor = col + col_factor = row_factor if col is None else col self._stretch_factors = (row_factor, col_factor) @@ -1340,7 +1338,7 @@ def autoRange(self, padding=None) -> None: def toggle_axes(self, axis) -> None: target = self.axes[axis] - toggle = False if target in self.ci.items.keys() else True + toggle = target not in self.ci.items if self.data_ndim == 2: ref_dims = ((1, 0, 1, 1), (0, 0, 1, 1), (1, 1, 1, 1)) @@ -1390,11 +1388,9 @@ def toggle_axes(self, axis) -> None: return anchors = tuple(ref_dims[i][:2] for i in group) - other_index = [ - x for x in group if x != axis and self.axes[x] in self.ci.items.keys() - ] + other_index = [x for x in group if x != axis and self.axes[x] in self.ci.items] other = [self.axes[i] for i in other_index] - unique = True if len(other) == 0 else False + unique = len(other) == 0 if not toggle: self.removeItem(target) if not unique: @@ -1567,9 +1563,10 @@ def _get_curr_axes_index(self, pos): for i, ax in enumerate(self.axes): if ax.vb.sceneBoundingRect().contains(pos): return i, self._get_mouse_datapos(ax, pos) - if self.colorbar is not None: - if self.colorbar.sceneBoundingRect().contains(pos): - return -1, self._get_mouse_datapos(self.colorbar, pos) + if self.colorbar is not None and self.colorbar.sceneBoundingRect().contains( + pos + ): + return -1, self._get_mouse_datapos(self.colorbar, pos) return None, None def _measure_fps(self) -> None: @@ -2377,10 +2374,9 @@ def fast_isocurve_chain(points): break lines_linked = [np.float64(x) for x in range(0)] for ch in points.values(): - if len(ch) == 2: - ch = ch[1][1:][::-1] + ch[0] # join together ends of chain - else: - ch = ch[0] + ch = ( + ch[1][1:][::-1] + ch[0] if len(ch) == 2 else ch[0] + ) # join together ends of chain lines_linked.append([p[0] for p in ch]) return lines_linked @@ -2404,10 +2400,7 @@ def generatePath(self) -> None: self.path = None return - if self.axisOrder == "row-major": - data = self.data.T - else: - data = self.data + data = self.data.T if self.axisOrder == "row-major" else self.data lines = fast_isocurve(data, self.level, self.connected, self.extendToEdge) # lines = pg.functions.isocurve( @@ -2507,7 +2500,7 @@ def cmap_changed(self) -> None: self.cmap = self.imageItem()._colorMap self.lut = self.imageItem().lut # self.lut = self.cmap.getStops()[1] - if not self.npts == self.lut.shape[0]: + if self.npts != self.lut.shape[0]: self.npts = self.lut.shape[0] self.cbar.setImage(self.cmap.pos.reshape((-1, 1))) self.cbar._colorMap = self.cmap @@ -2986,10 +2979,7 @@ def set_cmap(self, name=None) -> None: self._gamma_slider.blockSignals(True) self._gamma_slider.setValue(self.gamma_scale(gamma)) self._gamma_slider.blockSignals(False) - if isinstance(name, str): - cmap = name - else: - cmap = self._cmap_combo.currentText() + cmap = name if isinstance(name, str) else self._cmap_combo.currentText() mode = self._cmap_mode_button.isChecked() self.itool.set_cmap(cmap, gamma=gamma, reverse=reverse, high_contrast=mode) @@ -3024,10 +3014,8 @@ def __init__(self, *args, **kwargs) -> None: def load_thumbnail(self, index) -> None: if not self.thumbnails_loaded: text = self.itemText(index) - try: + with contextlib.suppress(KeyError): self.setItemIcon(index, QtGui.QIcon(pg_colormap_to_QPixmap(text))) - except KeyError: - pass def load_all(self) -> None: self.clear() diff --git a/src/erlab/interactive/imagetool/controls.py b/src/erlab/interactive/imagetool/controls.py index b012448d..f6b3c92b 100644 --- a/src/erlab/interactive/imagetool/controls.py +++ b/src/erlab/interactive/imagetool/controls.py @@ -10,6 +10,7 @@ "ItoolCrosshairControls", ] +import contextlib import types from typing import TYPE_CHECKING, cast @@ -89,10 +90,9 @@ def get_icon(self, icon: str): return qta.icon(icon) def refresh_icons(self) -> None: - if self.icon_key_off is not None: - if self.isChecked(): - self.setIcon(self.get_icon(self.icon_key_off)) - return + if self.icon_key_off is not None and self.isChecked(): + self.setIcon(self.get_icon(self.icon_key_off)) + return if self.icon_key_on is not None: self.setIcon(self.get_icon(self.icon_key_on)) @@ -210,10 +210,8 @@ def slicer_area(self, value: ImageSlicerArea) -> None: """ # ignore until https://bugreports.qt.io/browse/PYSIDE-229 is fixed - try: + with contextlib.suppress(RuntimeError): self.disconnect_signals() - except RuntimeError: - pass self._slicer_area = value clear_layout(self.layout()) self.sub_controls = [] diff --git a/src/erlab/interactive/imagetool/core.py b/src/erlab/interactive/imagetool/core.py index 065613b4..e7f6bd88 100644 --- a/src/erlab/interactive/imagetool/core.py +++ b/src/erlab/interactive/imagetool/core.py @@ -129,10 +129,7 @@ def suppress_history(method: Callable | None = None): def my_decorator(method: Callable): @functools.wraps(method) def wrapped(self, *args, **kwargs): - if hasattr(self, "slicer_area"): - area = self.slicer_area - else: - area = self + area = self.slicer_area if hasattr(self, "slicer_area") else self with area.history_suppressed(): return method(self, *args, **kwargs) @@ -149,10 +146,7 @@ def record_history(method: Callable | None = None): def my_decorator(method: Callable): @functools.wraps(method) def wrapped(self, *args, **kwargs): - if hasattr(self, "slicer_area"): - area = self.slicer_area - else: - area = self + area = self.slicer_area if hasattr(self, "slicer_area") else self area.sigWriteHistory.emit() with area.history_suppressed(): # Prevent making additional records within the method @@ -201,22 +195,19 @@ def wrapped(*args, **kwargs): skip_sync = kwargs.pop("__slicer_skip_sync", False) out = func(*args, **kwargs) - if args[0].is_linked: - if not skip_sync: - all_args = inspect.Signature.from_callable(func).bind( - *args, **kwargs + if args[0].is_linked and not skip_sync: + all_args = inspect.Signature.from_callable(func).bind(*args, **kwargs) + all_args.apply_defaults() + obj = all_args.arguments.pop("self") + if obj._linking_proxy is not None: + obj._linking_proxy.sync( + obj, + func.__name__, + all_args.arguments, + indices, + steps, + color, ) - all_args.apply_defaults() - obj = all_args.arguments.pop("self") - if obj._linking_proxy is not None: - obj._linking_proxy.sync( - obj, - func.__name__, - all_args.arguments, - indices, - steps, - color, - ) return out return wrapped @@ -311,7 +302,7 @@ def convert_args( steps: bool, ): if indices: - index: int | None = args.get("value", None) + index: int | None = args.get("value") if index is not None: axis: int | None = args.get("axis") @@ -891,9 +882,8 @@ def set_data( raise ValueError("No data variables found in Dataset") from e else: data = xr.DataArray(np.asarray(data)) - if hasattr(data.data, "flags"): - if not data.data.flags["WRITEABLE"]: - data = data.copy() + if hasattr(data.data, "flags") and not data.data.flags["WRITEABLE"]: + data = data.copy() if not rad2deg: self._data = data @@ -1611,10 +1601,8 @@ def update_manual_range(self) -> None: def set_range_from(self, limits: dict[str, list[float]], **kwargs) -> None: for dim, key in zip(self.axis_dims, ("xRange", "yRange"), strict=True): if dim is not None: - try: + with contextlib.suppress(KeyError): kwargs[key] = limits[dim] - except KeyError: - pass if len(kwargs) != 0: self.setRange(**kwargs) @@ -1836,10 +1824,9 @@ def disconnect_signals(self) -> None: @QtCore.Slot(int, object) def refresh_items_data(self, cursor: int, axes: tuple[int] | None = None) -> None: self.refresh_cursor(cursor) - if axes is not None: + if axes is not None and all(elem in self.display_axis for elem in axes): # display_axis는 축 dim 표시하는거임. 즉 해당 축만 바뀌면 데이터 변화 없음 - if all(elem in self.display_axis for elem in axes): - return + return for item in self.slicer_data_items: if item.cursor_index != cursor: continue diff --git a/src/erlab/interactive/imagetool/manager.py b/src/erlab/interactive/imagetool/manager.py index a3b6eefe..c64a2e91 100644 --- a/src/erlab/interactive/imagetool/manager.py +++ b/src/erlab/interactive/imagetool/manager.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import functools import gc import os @@ -119,10 +120,8 @@ def run(self) -> None: os.remove(f) dirname = os.path.dirname(f) if os.path.isdir(dirname): - try: + with contextlib.suppress(OSError): os.rmdir(dirname) - except OSError: - pass except ( pickle.UnpicklingError, AttributeError, @@ -594,7 +593,7 @@ def closeEvent(self, event: QtGui.QCloseEvent | None) -> None: msg = f"All {self.ntools} remaining windows will be closed." ret = QtWidgets.QMessageBox.question(self, "Do you want to close?", msg) - if not ret == QtWidgets.QMessageBox.StandardButton.Yes: + if ret != QtWidgets.QMessageBox.StandardButton.Yes: if event: event.ignore() return diff --git a/src/erlab/interactive/utils.py b/src/erlab/interactive/utils.py index e1388746..060b0ab7 100644 --- a/src/erlab/interactive/utils.py +++ b/src/erlab/interactive/utils.py @@ -82,13 +82,9 @@ def copy_to_clipboard(content: str | list[str]) -> str: def _parse_single_arg(arg): if isinstance(arg, str): - if arg.startswith("|") and arg.endswith("|"): - # If the string is surrounded by vertical bars, remove them - arg = arg[1:-1] - - else: - # Otherwise, quote the string - arg = f'"{arg}"' + # If the string is surrounded by vertical bars, remove them + # Otherwise, quote the string + arg = arg[1:-1] if arg.startswith("|") and arg.endswith("|") else f'"{arg}"' elif isinstance(arg, dict): # If the argument is a dict, convert to string with double quotes @@ -190,7 +186,7 @@ def format_kwargs(d: dict[str, Any]) -> str: Dictionary of keyword arguments. """ - if all(s.isidentifier() for s in d.keys()): + if all(s.isidentifier() for s in d): return ", ".join(f"{k}={_parse_single_arg(v)!s}" for k, v in d.items()) out = ", ".join(f'"{k}": {_parse_single_arg(v)!s}' for k, v in d.items()) return "{" + out + "}" @@ -426,10 +422,7 @@ def stepEnabled(self): return self.StepEnabledFlag.StepNone def setValue(self, val) -> None: - if np.isnan(val): - val = np.nan - else: - val = max(self.minimum(), min(val, self.maximum())) + val = np.nan if np.isnan(val) else max(self.minimum(), min(val, self.maximum())) if self._only_int and np.isfinite(val): val = round(val) @@ -516,10 +509,7 @@ def __init__(self, *args, **kwargs) -> None: def updateAutoSIPrefix(self) -> None: if self.label.isVisible(): - if self.logMode: - _range = 10 ** np.array(self.range) - else: - _range = self.range + _range = 10 ** np.array(self.range) if self.logMode else self.range (scale, prefix) = pg.siScale( max(abs(_range[0] * self.scale), abs(_range[1] * self.scale)) ) @@ -582,10 +572,7 @@ def labelString(self) -> str: else: units = f"({self.labelUnitPrefix}{self.labelUnits})" - if self.labelText == "": - s = units - else: - s = f"{self.labelText} {units}" + s = units if self.labelText == "" else f"{self.labelText} {units}" style = ";".join([f"{k}: {v}" for k, v in self.labelStyle.items()]) @@ -904,10 +891,7 @@ def __init__( self.untracked = [] self.widgets: dict[str, QtWidgets.QWidget] = {} - if widgets is not None: - kwargs = widgets - else: - kwargs = widgets_kwargs + kwargs = widgets if widgets is not None else widgets_kwargs j = 0 for i, (k, v) in enumerate(kwargs.items()): @@ -1513,10 +1497,7 @@ def __init__(self, *args, **kwargs) -> None: self.mainfunc_kwargs: dict[str, Any] = {} def call_prefunc(self, x): - if self.prefunc_only_values: - xval = np.asarray(x) - else: - xval = x + xval = np.asarray(x) if self.prefunc_only_values else x return self.prefunc(xval, **self.prefunc_kwargs) def set_input(self, data=None) -> None: diff --git a/src/erlab/io/dataloader.py b/src/erlab/io/dataloader.py index 7417f735..70d539e8 100644 --- a/src/erlab/io/dataloader.py +++ b/src/erlab/io/dataloader.py @@ -37,6 +37,7 @@ Callable, ItemsView, Iterable, + Iterator, KeysView, Mapping, Sequence, @@ -718,12 +719,13 @@ def _update_plot(_) -> None: plt.title("") # Remove automatically generated title # Add line at Fermi level if the data is 2D and has an energy dimension - if plot_data.ndim == 2 and "eV" in plot_data.dims: - # Check if binding - if plot_data["eV"].values[0] * plot_data["eV"].values[-1] < 0: - eplt.fermiline( - orientation="h" if plot_data.dims[0] == "eV" else "v" - ) + # that includes zero + if (plot_data.ndim == 2 and "eV" in plot_data.dims) and ( + plot_data["eV"].values[0] * plot_data["eV"].values[-1] < 0 + ): + eplt.fermiline( + orientation="h" if plot_data.dims[0] == "eV" else "v" + ) show_inline_matplotlib_plots() def _next(_) -> None: @@ -1091,10 +1093,7 @@ def load_multiple_parallel( A list of the loaded data. """ if n_jobs is None: - if len(file_paths) < 15: - n_jobs = 1 - else: - n_jobs = -1 + n_jobs = 1 if len(file_paths) < 15 else -1 return joblib.Parallel(n_jobs=n_jobs)( joblib.delayed(self.load_single)(f) for f in file_paths @@ -1173,6 +1172,9 @@ def get(self, key: str) -> LoaderBase: return loader + def __iter__(self) -> Iterator[str]: + return iter(self.loaders) + def __getitem__(self, key: str) -> LoaderBase: return self.get(key) diff --git a/src/erlab/io/igor.py b/src/erlab/io/igor.py index f1cf8035..c4f615cb 100644 --- a/src/erlab/io/igor.py +++ b/src/erlab/io/igor.py @@ -1,3 +1,4 @@ +import contextlib import os from typing import Any @@ -44,9 +45,8 @@ def _load_experiment_raw( def unpack_folders(expt) -> None: for name, record in expt.items(): if isinstance(record, igor2.record.WaveRecord): - if prefix is not None: - if not name.decode().startswith(prefix): - continue + if prefix is not None and not name.decode().startswith(prefix): + continue if name.decode() in ignore: continue waves[name.decode()] = load_wave(record, **kwargs) @@ -227,10 +227,8 @@ def get_dim_name(index): try: v = int(v) except ValueError: - try: + with contextlib.suppress(ValueError): v = float(v) - except ValueError: - pass attrs[k] = v return xr.DataArray( diff --git a/src/erlab/io/plugins/merlin.py b/src/erlab/io/plugins/merlin.py index 55fdc715..ab10ae99 100644 --- a/src/erlab/io/plugins/merlin.py +++ b/src/erlab/io/plugins/merlin.py @@ -201,10 +201,7 @@ def generate_summary( for name, path in files.items(): if os.path.splitext(path)[1] == ".ibw": data = self.load_live(path) - if "beta" in data.dims: - data_type = "LP" - else: - data_type = "LXY" + data_type = "LP" if "beta" in data.dims else "LXY" else: idx, _ = self.infer_index(os.path.splitext(os.path.basename(path))[0]) if idx in processed_indices: diff --git a/src/erlab/io/plugins/ssrl52.py b/src/erlab/io/plugins/ssrl52.py index 660ca9cb..2846fe90 100644 --- a/src/erlab/io/plugins/ssrl52.py +++ b/src/erlab/io/plugins/ssrl52.py @@ -223,11 +223,14 @@ def post_process(self, data: xr.DataArray) -> xr.DataArray: data = data.assign_attrs(temp_sample=temp) # Convert to binding energy - if "sample_workfunction" in data.attrs and "eV" in data.dims: - if data.eV.min() > 0: - data = data.assign_coords( - eV=data.eV - float(data.hv) + data.attrs["sample_workfunction"] - ) + if ( + "sample_workfunction" in data.attrs + and "eV" in data.dims + and data["eV"].min() > 0 + ): + data = data.assign_coords( + eV=data["eV"] - float(data["hv"]) + data.attrs["sample_workfunction"] + ) return data diff --git a/src/erlab/io/utils.py b/src/erlab/io/utils.py index e0174c50..ebe91d9d 100644 --- a/src/erlab/io/utils.py +++ b/src/erlab/io/utils.py @@ -207,9 +207,7 @@ def save_as_hdf5( if isinstance(v, dict): data = data.assign_attrs({k: str(v)}) - if isinstance(data, xr.Dataset): - igor_compat = False - elif data.ndim > 4: + if isinstance(data, xr.Dataset) or data.ndim > 4: igor_compat = False if igor_compat: diff --git a/src/erlab/plotting/annotations.py b/src/erlab/plotting/annotations.py index 79c4394b..524a059b 100644 --- a/src/erlab/plotting/annotations.py +++ b/src/erlab/plotting/annotations.py @@ -148,10 +148,7 @@ def _alph_label(val, prefix, suffix, numeric, capital): if numeric: val = str(val) else: - if capital: - ref_char = "A" - else: - ref_char = "a" + ref_char = "A" if capital else "a" val = chr(int(val) + ord(ref_char) - 1) elif not isinstance(val, str): raise TypeError("Input values must be integers or strings.") @@ -232,7 +229,7 @@ def label_for_dim(dim_name: str, deg2rad: bool = False, escaped: bool = True) -> def parse_special_point(name: str) -> str: special_points = {"G": r"\Gamma", "D": r"\Delta"} - if name in special_points.keys(): + if name in special_points: return special_points[name] return name @@ -243,16 +240,10 @@ def parse_point_labels(name: str, roman: bool = True, bar: bool = False) -> str: if name.endswith("*"): name = name[:-1] - if roman: - format_str = r"\mathdefault{{{}}}^*" - else: - format_str = r"{}^*" + format_str = "\\mathdefault{{{}}}^*" if roman else "{}^*" elif name.endswith("'"): name = name[:-1] - if roman: - format_str = r"\mathdefault{{{}}}\prime" - else: - format_str = r"{}\prime" + format_str = "\\mathdefault{{{}}}\\prime" if roman else "{}\\prime" elif roman: format_str = r"\mathdefault{{{}}}" else: @@ -260,12 +251,7 @@ def parse_point_labels(name: str, roman: bool = True, bar: bool = False) -> str: name = format_str.format(parse_special_point(name)) - if bar: - name = rf"$\overline{{{name}}}$" - else: - name = rf"${name}$" - - return name + return f"$\\overline{{{name}}}$" if bar else f"${name}$" def copy_mathtext( @@ -506,10 +492,7 @@ def label_subplots( for i, ax in enumerate(axlist): if fontsize is None: - if isinstance(ax, matplotlib.figure.Figure): - fontsize = "large" - else: - fontsize = "medium" + fontsize = "large" if isinstance(ax, matplotlib.figure.Figure) else "medium" label_str = _alph_label(value_arr[i], prefix, suffix, numeric, capital) with plt.rc_context({"text.color": axes_textcolor(ax)}): @@ -818,10 +801,7 @@ def plot_hv_text_right(ax, val, x=1 - 0.025, y=0.975, **kwargs) -> None: def property_label(key, value, decimals=None, si=0, name=None, unit=None) -> str: - if name == "": - delim = "" - else: - delim = " = " + delim = "" if name == "" else " = " if name is None: name = name_for_dim(key, escaped=False) if name is None: diff --git a/src/erlab/plotting/atoms.py b/src/erlab/plotting/atoms.py index d7bd8d62..0833a5ff 100644 --- a/src/erlab/plotting/atoms.py +++ b/src/erlab/plotting/atoms.py @@ -178,9 +178,11 @@ def draw(self, renderer) -> None: proj_sizes = np.sqrt(self.sizes_orig) super().set_sizes(proj_sizes**2, self.figure.dpi) - with self._use_zordered_offset(): - with matplotlib.cbook._setattr_cm(self, _in_draw=True): - matplotlib.collections.Collection.draw(self, renderer) + with ( + self._use_zordered_offset(), + matplotlib.cbook._setattr_cm(self, _in_draw=True), + ): + matplotlib.collections.Collection.draw(self, renderer) # def draw(self, renderer): # # Note: unlike in the 2D case, where we can enforce equal diff --git a/src/erlab/plotting/bz.py b/src/erlab/plotting/bz.py index d661e5a1..d58bbbbc 100644 --- a/src/erlab/plotting/bz.py +++ b/src/erlab/plotting/bz.py @@ -130,10 +130,7 @@ def plot_hex_bz( ] kwargs["edgecolor"] = kwargs.pop("edgecolor", kwargs.pop("ec", axes_textcolor(ax))) - if reciprocal: - r = 4 * np.pi / (a * 3) - else: - r = 2 * a + r = 4 * np.pi / (a * 3) if reciprocal else 2 * a clip = kwargs.pop("clip_path", None) poly = RegularPolygon(offset, 6, radius=r, orientation=np.deg2rad(rotate), **kwargs) diff --git a/src/erlab/plotting/colors.py b/src/erlab/plotting/colors.py index eebb5cfc..3e1682c3 100644 --- a/src/erlab/plotting/colors.py +++ b/src/erlab/plotting/colors.py @@ -45,6 +45,7 @@ "unify_clim", ] +import contextlib from collections.abc import Iterable, Sequence from numbers import Number from typing import Any, Literal, cast @@ -484,14 +485,13 @@ def get_mappable( except (IndexError, AttributeError): mappable = None - if mappable is None: - if not silent: - raise RuntimeError( - "No mappable was found to use for colorbar " - "creation. First define a mappable such as " - "an image (with imshow) or a contour set (" - "with contourf)." - ) + if mappable is None and not silent: + raise RuntimeError( + "No mappable was found to use for colorbar " + "creation. First define a mappable such as " + "an image (with imshow) or a contour set (" + "with contourf)." + ) return mappable @@ -978,9 +978,8 @@ def _get_segment_for_color( cmap: matplotlib.colors.LinearSegmentedColormap, color: Literal["red", "green", "blue", "alpha"], ) -> Any: - if hasattr(cmap, "_segmentdata"): - if color in cmap._segmentdata: - return cmap._segmentdata[color] + if hasattr(cmap, "_segmentdata") and color in cmap._segmentdata: + return cmap._segmentdata[color] return None @@ -988,9 +987,14 @@ def _is_segment_iterable(cmap: matplotlib.colors.Colormap) -> bool: if not isinstance(cmap, matplotlib.colors.LinearSegmentedColormap): return False - if any(callable(_get_segment_for_color(cmap, c)) for c in ["red", "green", "blue"]): # type: ignore[arg-type] - return False - return True + return not any( + callable( + _get_segment_for_color( + cmap, cast(Literal["red", "green", "blue", "alpha"], c) + ) + ) + for c in ["red", "green", "blue"] + ) def combined_cmap( @@ -1255,16 +1259,16 @@ def axes_textcolor( """ c = light mappable = get_mappable(ax, silent=True) - if mappable is not None: - if isinstance( + if ( + mappable is not None + and isinstance( mappable, matplotlib.image._ImageBase | matplotlib.collections.QuadMesh - ): - if not image_is_light(mappable): - c = dark + ) + and not image_is_light(mappable) + ): + c = dark return c -try: +with contextlib.suppress(ValueError): combined_cmap("bone_r", "hot", "bonehot", register=True) -except ValueError: - pass diff --git a/src/erlab/plotting/general.py b/src/erlab/plotting/general.py index 6c4579cd..dc2452ad 100644 --- a/src/erlab/plotting/general.py +++ b/src/erlab/plotting/general.py @@ -59,10 +59,7 @@ def figwh(ratio=0.6180339887498948, wide=0, wscale=1, style="aps", fixed_height= if isinstance(ratio, str): ratio = float(ratio) * 2 / (1 + np.sqrt(5)) w = figure_width_ref[style][wide] - if fixed_height: - h = w * ratio - else: - h = w * wscale * ratio + h = w * ratio if fixed_height else w * wscale * ratio return w * wscale, h @@ -529,14 +526,8 @@ def plot_array_2d( >>> carr = xr.DataArray([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) >>> eplt.plot_array_2d(larr, carr) """ - if lnorm is None: - lnorm = plt.Normalize() - else: - lnorm = copy.deepcopy(lnorm) - if cnorm is None: - cnorm = plt.Normalize() - else: - cnorm = copy.deepcopy(cnorm) + lnorm = plt.Normalize() if lnorm is None else copy.deepcopy(lnorm) + cnorm = plt.Normalize() if cnorm is None else copy.deepcopy(cnorm) if colorbar_kw is None: colorbar_kw = {} if imshow_kw is None: diff --git a/src/erlab/utils/array.py b/src/erlab/utils/array.py index 22140ffb..5a644558 100644 --- a/src/erlab/utils/array.py +++ b/src/erlab/utils/array.py @@ -100,10 +100,7 @@ def is_dims_uniform( if dims is None: dims = darr.dims - for dim in dims: - if not is_uniform_spaced(darr[dim].values, **kwargs): - return False - return True + return all(is_uniform_spaced(darr[dim].values, **kwargs) for dim in dims) def check_arg_2d_darr(func: Callable | None = None):