Skip to content

Commit

Permalink
fix(analysis.image): correct argument order parsing in some filters
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed May 13, 2024
1 parent 0321ec1 commit 6043413
Showing 1 changed file with 64 additions and 35 deletions.
99 changes: 64 additions & 35 deletions src/erlab/analysis/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import itertools
import math
from collections.abc import Collection, Hashable, Mapping, Sequence, Sized
from collections.abc import Collection, Hashable, Mapping, Sequence
from typing import Literal

import numpy as np
Expand All @@ -26,6 +26,50 @@
from numba import carray, cfunc, types


def _parse_dict_arg(
dims: Sequence[Hashable],
sigma: float | Collection[float] | Mapping[Hashable, float],
arg_name: str,
reference_name: str,
allow_subset: bool = False,
) -> dict[Hashable, float]:
if isinstance(sigma, Mapping):
sigma_dict = dict(sigma)

elif np.isscalar(sigma):
sigma_dict = dict.fromkeys(dims, sigma)

elif isinstance(sigma, Collection):
if len(sigma) != len(dims):
raise ValueError(
f"`{arg_name}` does not match dimensions of {reference_name}"
)

sigma_dict = dict(zip(dims, sigma, strict=True))

else:
raise TypeError(f"`{arg_name}` must be a scalar, sequence, or mapping")

if not allow_subset and len(sigma_dict) != len(dims):
required_dims = set(dims) - set(sigma_dict.keys())
raise ValueError(
f"`{arg_name}` missing for the following dimension"
f"{'' if len(required_dims) == 1 else 's'}: {required_dims}"
)

else:
for d in sigma_dict.keys():
if d not in dims:
raise ValueError(
f"Dimension `{d}` in {arg_name} not found in {reference_name}"
)

# Make sure that sigma_dict is ordered in temrs of data dims
sigma_dict = {d: sigma_dict[d] for d in dims if d in sigma_dict.keys()}

return sigma_dict


def gaussian_filter(
darr: xr.DataArray,
sigma: float | Collection[float] | Mapping[Hashable, float],
Expand Down Expand Up @@ -108,35 +152,28 @@ def gaussian_filter(
Dimensions without coordinates: x, y
"""
if isinstance(sigma, Mapping):
sigma_dict = dict(sigma)
elif np.isscalar(sigma):
sigma_dict = dict.fromkeys(darr.dims, sigma)
elif isinstance(sigma, Collection):
sigma_dict = dict(zip(darr.dims, sigma, strict=True))
else:
raise TypeError("`sigma` must be a scalar, sequence, or mapping")
sigma_dict: dict[Hashable, float] = _parse_dict_arg(
darr.dims,
sigma,
arg_name="sigma",
reference_name="DataArray",
allow_subset=True,
)

# Get the axis indices to apply the filter
axes = tuple(darr.get_axis_num(d) for d in sigma_dict.keys())

# Convert arguments to tuples acceptable by scipy
if isinstance(order, Mapping):
order = tuple(order.get(str(d), 0) for d in sigma_dict.keys())

if isinstance(mode, Mapping):
mode = tuple(mode[str(d)] for d in sigma_dict.keys())

if radius is not None:
if isinstance(radius, Mapping):
radius_dict = dict(radius)
elif isinstance(radius, Sized):
if len(radius) != len(sigma_dict):
raise ValueError("`radius` does not match dimensions of `sigma`")
radius_dict = dict(zip(sigma_dict.keys(), radius, strict=True))
elif np.isscalar(radius):
radius_dict = dict.fromkeys(sigma_dict.keys(), radius)
else:
raise TypeError("`radius` must be a scalar, sequence, or mapping")
radius_dict = _parse_dict_arg(
tuple(sigma_dict.keys()), radius, "radius", "`sigma`"
)

# Calculate radius in pixels
radius_pix: tuple[int, ...] | None = tuple(
Expand Down Expand Up @@ -167,7 +204,7 @@ def gaussian_filter(

def gaussian_laplace(
darr: xr.DataArray,
sigma: float | Collection[float] | Mapping[str, float],
sigma: float | Collection[float] | Mapping[Hashable, float],
mode: str | Sequence[str] | Mapping[str, str] = "nearest",
cval: float = 0.0,
**kwargs,
Expand Down Expand Up @@ -214,21 +251,13 @@ def gaussian_laplace(
:func:`scipy.ndimage.gaussian_laplace` : The underlying function used to apply the
filter.
"""
if isinstance(sigma, Mapping):
sigma_dict = dict(sigma)
elif np.isscalar(sigma):
sigma_dict = dict.fromkeys(darr.dims, sigma)
elif isinstance(sigma, Collection):
sigma_dict = dict(zip(darr.dims, sigma, strict=True))
else:
raise TypeError("`sigma` must be a scalar, sequence, or mapping")

if len(sigma_dict) != darr.ndim:
required_dims = set(darr.dims) - set(sigma_dict.keys())
raise ValueError(
"`sigma` missing for the following dimension"
f"{'' if len(required_dims) == 1 else 's'}: {required_dims}"
)
sigma_dict: dict[Hashable, float] = _parse_dict_arg(
darr.dims,
sigma,
arg_name="sigma",
reference_name="DataArray",
allow_subset=False,
)

# Convert mode to tuple acceptable by scipy
if isinstance(mode, dict):
Expand Down

0 comments on commit 6043413

Please sign in to comment.