From 60434136224c0875ed8fba41d24e32fc6868127c Mon Sep 17 00:00:00 2001 From: Kimoon Han Date: Mon, 13 May 2024 15:07:15 +0900 Subject: [PATCH] fix(analysis.image): correct argument order parsing in some filters --- src/erlab/analysis/image.py | 99 ++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 35 deletions(-) diff --git a/src/erlab/analysis/image.py b/src/erlab/analysis/image.py index 22f5839..36dfffd 100644 --- a/src/erlab/analysis/image.py +++ b/src/erlab/analysis/image.py @@ -15,7 +15,7 @@ import itertools import math -from collections.abc import Collection, Hashable, Mapping, Sequence, Sized +from collections.abc import Collection, Hashable, Mapping, Sequence from typing import Literal import numpy as np @@ -26,6 +26,50 @@ from numba import carray, cfunc, types +def _parse_dict_arg( + dims: Sequence[Hashable], + sigma: float | Collection[float] | Mapping[Hashable, float], + arg_name: str, + reference_name: str, + allow_subset: bool = False, +) -> dict[Hashable, float]: + if isinstance(sigma, Mapping): + sigma_dict = dict(sigma) + + elif np.isscalar(sigma): + sigma_dict = dict.fromkeys(dims, sigma) + + elif isinstance(sigma, Collection): + if len(sigma) != len(dims): + raise ValueError( + f"`{arg_name}` does not match dimensions of {reference_name}" + ) + + sigma_dict = dict(zip(dims, sigma, strict=True)) + + else: + raise TypeError(f"`{arg_name}` must be a scalar, sequence, or mapping") + + if not allow_subset and len(sigma_dict) != len(dims): + required_dims = set(dims) - set(sigma_dict.keys()) + raise ValueError( + f"`{arg_name}` missing for the following dimension" + f"{'' if len(required_dims) == 1 else 's'}: {required_dims}" + ) + + else: + for d in sigma_dict.keys(): + if d not in dims: + raise ValueError( + f"Dimension `{d}` in {arg_name} not found in {reference_name}" + ) + + # Make sure that sigma_dict is ordered in temrs of data dims + sigma_dict = {d: sigma_dict[d] for d in dims if d in sigma_dict.keys()} + + return sigma_dict + + def gaussian_filter( darr: xr.DataArray, sigma: float | Collection[float] | Mapping[Hashable, float], @@ -108,14 +152,13 @@ def gaussian_filter( Dimensions without coordinates: x, y """ - if isinstance(sigma, Mapping): - sigma_dict = dict(sigma) - elif np.isscalar(sigma): - sigma_dict = dict.fromkeys(darr.dims, sigma) - elif isinstance(sigma, Collection): - sigma_dict = dict(zip(darr.dims, sigma, strict=True)) - else: - raise TypeError("`sigma` must be a scalar, sequence, or mapping") + sigma_dict: dict[Hashable, float] = _parse_dict_arg( + darr.dims, + sigma, + arg_name="sigma", + reference_name="DataArray", + allow_subset=True, + ) # Get the axis indices to apply the filter axes = tuple(darr.get_axis_num(d) for d in sigma_dict.keys()) @@ -123,20 +166,14 @@ def gaussian_filter( # Convert arguments to tuples acceptable by scipy if isinstance(order, Mapping): order = tuple(order.get(str(d), 0) for d in sigma_dict.keys()) + if isinstance(mode, Mapping): mode = tuple(mode[str(d)] for d in sigma_dict.keys()) if radius is not None: - if isinstance(radius, Mapping): - radius_dict = dict(radius) - elif isinstance(radius, Sized): - if len(radius) != len(sigma_dict): - raise ValueError("`radius` does not match dimensions of `sigma`") - radius_dict = dict(zip(sigma_dict.keys(), radius, strict=True)) - elif np.isscalar(radius): - radius_dict = dict.fromkeys(sigma_dict.keys(), radius) - else: - raise TypeError("`radius` must be a scalar, sequence, or mapping") + radius_dict = _parse_dict_arg( + tuple(sigma_dict.keys()), radius, "radius", "`sigma`" + ) # Calculate radius in pixels radius_pix: tuple[int, ...] | None = tuple( @@ -167,7 +204,7 @@ def gaussian_filter( def gaussian_laplace( darr: xr.DataArray, - sigma: float | Collection[float] | Mapping[str, float], + sigma: float | Collection[float] | Mapping[Hashable, float], mode: str | Sequence[str] | Mapping[str, str] = "nearest", cval: float = 0.0, **kwargs, @@ -214,21 +251,13 @@ def gaussian_laplace( :func:`scipy.ndimage.gaussian_laplace` : The underlying function used to apply the filter. """ - if isinstance(sigma, Mapping): - sigma_dict = dict(sigma) - elif np.isscalar(sigma): - sigma_dict = dict.fromkeys(darr.dims, sigma) - elif isinstance(sigma, Collection): - sigma_dict = dict(zip(darr.dims, sigma, strict=True)) - else: - raise TypeError("`sigma` must be a scalar, sequence, or mapping") - - if len(sigma_dict) != darr.ndim: - required_dims = set(darr.dims) - set(sigma_dict.keys()) - raise ValueError( - "`sigma` missing for the following dimension" - f"{'' if len(required_dims) == 1 else 's'}: {required_dims}" - ) + sigma_dict: dict[Hashable, float] = _parse_dict_arg( + darr.dims, + sigma, + arg_name="sigma", + reference_name="DataArray", + allow_subset=False, + ) # Convert mode to tuple acceptable by scipy if isinstance(mode, dict):