Skip to content

Commit

Permalink
feat: add callable fit accessor using apply_ufunc
Browse files Browse the repository at this point in the history
Add a `Dataset.modelfit` and `DataArray.modelfit` accessor with syntax and output similar to `Dataset.curvefit`. Closes #22
  • Loading branch information
kmnhan committed Apr 17, 2024
1 parent 8a47325 commit e06982d
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/erlab/accessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
"""

# Public API of the accessors subpackage, kept in sorted order.
# Fix: "ImageToolAccessor" was listed twice; duplicates in __all__ are
# harmless at runtime but confuse linters and `help()` output.
__all__ = [
    "ImageToolAccessor",
    "ModelFitDataArrayAccessor",
    "ModelFitDatasetAccessor",
    "MomentumAccessor",
    "OffsetView",
    "PlotAccessor",
    "SelectionAccessor",
]

from erlab.accessors.fit import ModelFitDataArrayAccessor, ModelFitDatasetAccessor
from erlab.accessors.kspace import MomentumAccessor, OffsetView
from erlab.accessors.utils import ImageToolAccessor, PlotAccessor, SelectionAccessor
321 changes: 321 additions & 0 deletions src/erlab/accessors/fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
__all__ = ["ModelFitDataArrayAccessor", "ModelFitDatasetAccessor"]

import warnings
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import Any, Literal

import lmfit
import numpy as np
import xarray as xr
from xarray.core.types import Dims

from erlab.accessors.utils import _THIS_ARRAY, ERLabAccessor


@xr.register_dataset_accessor("modelfit")
class ModelFitDatasetAccessor(ERLabAccessor):
    """`xarray.Dataset.modelfit` accessor for fitting lmfit models."""

    def __call__(
        self,
        coords: str | xr.DataArray | Iterable[str | xr.DataArray],
        model: lmfit.Model,
        reduce_dims: Dims = None,
        skipna: bool = True,
        params: str
        | lmfit.Parameters
        | Mapping[str, float | dict[str, Any]]
        | xr.DataArray
        | None = None,
        guess: bool = False,
        errors: Literal["raise", "ignore"] = "raise",
        **kwargs: dict[str, Any],
    ) -> xr.Dataset:
        """
        Curve fitting optimization for arbitrary functions.

        Wraps :meth:`lmfit.Model.fit` with :func:`xarray.apply_ufunc`.

        Parameters
        ----------
        coords : hashable, xarray.DataArray, or sequence of hashable or xarray.DataArray
            Independent coordinate(s) over which to perform the curve fitting. Must
            share at least one dimension with the calling object. When fitting
            multi-dimensional functions, supply `coords` as a sequence in the same
            order as arguments in `func`. To fit along existing dimensions of the
            calling object, `coords` can also be specified as a str or sequence of
            strs.
        model : lmfit.Model
            A model object to fit to the data. The model must be an instance of
            `lmfit.Model`.
        reduce_dims : str, Iterable of Hashable or None, optional
            Additional dimension(s) over which to aggregate while fitting. For
            example, calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'],
            ...)` will aggregate all lat and lon points and fit the specified
            function along the time dimension.
        skipna : bool, default: True
            Whether to skip missing values when fitting. Default is True.
        params : str, lmfit.Parameters, dict-like, or xarray.DataArray, optional
            Optional input parameters to the fit. If a string, it should be a JSON
            string representation of the parameters, generated by
            `lmfit.Parameters.dumps`. If a `lmfit.Parameters` object, it will be
            used as is. If a dict-like object, it will be converted to a
            `lmfit.Parameters` object. If the values are DataArrays, they will be
            appropriately broadcast to the coordinates of the array. If none or only
            some parameters are passed, the rest will be assigned initial values and
            bounds with :meth:`lmfit.Model.make_params`, or guessed with
            :meth:`lmfit.Model.guess` if `guess` is `True`.
        guess : bool, default: `False`
            Whether to guess the initial parameters with :meth:`lmfit.Model.guess`.
            For composite models, the parameters will be guessed for each component.
        errors : {"raise", "ignore"}, default: "raise"
            If `'raise'`, any errors from the :meth:`lmfit.Model.fit` optimization
            will raise an exception. If `'ignore'`, the return values for the
            coordinates where the fitting failed will be NaN.
        **kwargs : optional
            Additional keyword arguments to be passed to :meth:`lmfit.Model.fit`.

        Returns
        -------
        curvefit_results : xarray.Dataset
            A single dataset which contains:

            [var]_modelfit_coefficients
                The coefficients of the best fit.
            [var]_modelfit_stderr
                The standard errors of the coefficients.
            [var]_modelfit_stats
                Statistics from the fit. See :func:`lmfit.minimize`.

        See Also
        --------
        xarray.Dataset.curve_fit xarray.Dataset.polyfit lmfit.model.Model.fit
        scipy.optimize.curve_fit
        """
        # Implementation analogous to xarray.Dataset.curve_fit

        if params is None:
            params = lmfit.create_params()

        # Input to apply_ufunc cannot be a Mapping, so convert parameters to a
        # JSON string via `Parameters.dumps`; it is deserialized in `_wrapper`.
        if isinstance(params, lmfit.Parameters):
            params = params.dumps()
        elif isinstance(params, Mapping):
            # Given as a mapping from str to float or dict
            params = lmfit.create_params(**params).dumps()

        reduce_dims_: list[Hashable]
        if not reduce_dims:
            reduce_dims_ = []
        elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable):
            reduce_dims_ = [reduce_dims]
        else:
            reduce_dims_ = list(reduce_dims)

        if isinstance(coords, (str, xr.DataArray)) or not isinstance(
            coords, Iterable
        ):
            coords = [coords]
        coords_: Sequence[xr.DataArray] = [
            self._obj[coord] if isinstance(coord, str) else coord for coord in coords
        ]

        # Determine whether any coords are dims on self._obj
        for coord in coords_:
            reduce_dims_ += [c for c in self._obj.dims if coord.equals(self._obj[c])]
        reduce_dims_ = list(set(reduce_dims_))
        preserved_dims = list(set(self._obj.dims) - set(reduce_dims_))
        if not reduce_dims_:
            raise ValueError(
                "No arguments to `coords` were identified as a dimension on the "
                "calling object, and no dims were supplied to `reduce_dims`. This "
                "would result in fitting on scalar data."
            )

        # Check that initial guess and bounds only contain coords in preserved_dims
        if isinstance(params, xr.DataArray):
            unexpected = set(params.dims) - set(preserved_dims)
            if unexpected:
                raise ValueError(
                    f"Initial guess has unexpected dimensions {tuple(unexpected)}. It "
                    "should only have dimensions that are in data dimensions "
                    f"{preserved_dims}."
                )

        if errors not in ["raise", "ignore"]:
            raise ValueError('errors must be either "raise" or "ignore"')

        # Broadcast all coords with each other
        coords_ = xr.broadcast(*coords_)
        coords_ = [
            coord.broadcast_like(self._obj, exclude=preserved_dims)
            for coord in coords_
        ]
        n_coords = len(coords_)

        # Get the parameter names
        param_names: list[str] = model.param_names
        n_params = len(param_names)

        # Define the statistics to extract from the fit result
        stat_names = [
            "nfev",
            "nvarys",
            "ndata",
            "nfree",
            "chisqr",
            "redchi",
            "aic",
            "bic",
        ]
        n_stats = len(stat_names)

        def _nan_result() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
            # All-NaN outputs for empty or failed fits. Note that the stderr
            # array must be 1D of length n_params to match the "param" output
            # core dim (the previous [n_params, n_params] shape was a holdover
            # from xarray's curvefit, where it is a covariance matrix).
            return (
                np.full([n_params], np.nan),
                np.full([n_params], np.nan),
                np.full([n_stats], np.nan),
            )

        def _wrapper(Y, *args, **kwargs):
            # Wrap Model.fit with raveled coordinates and pointwise NaN handling
            # *args contains:
            # - the coordinates
            # - parameters object (serialized as a JSON string)
            coords__ = args[:n_coords]
            init_params_ = args[n_coords]

            # Start from an empty parameter set when guessing so that guessed
            # values are only overridden by explicitly supplied parameters.
            if guess:
                initial_params = lmfit.create_params()
            else:
                initial_params = model.make_params()
            if isinstance(init_params_, str):
                initial_params.update(lmfit.Parameters().loads(init_params_))
            elif isinstance(init_params_, lmfit.Parameters):
                initial_params.update(init_params_)
            elif isinstance(init_params_, Mapping):
                for p, v in init_params_.items():
                    if isinstance(v, Mapping):
                        initial_params[p].set(**v)
                    else:
                        initial_params[p].set(value=v)

            x = np.vstack([c.ravel() for c in coords__])
            y = Y.ravel()
            if skipna:
                mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0)
                x = x[:, mask]
                y = y[mask]
                if not len(y):
                    return _nan_result()
            x = np.squeeze(x)

            if n_coords == 1:
                indep_var_kwargs = {model.independent_vars[0]: x}
            else:
                indep_var_kwargs = dict(zip(model.independent_vars[:n_coords], x))

            if guess:
                if isinstance(model, lmfit.model.CompositeModel):
                    guessed_params = model.make_params()
                    for comp in model.components:
                        try:
                            guessed_params.update(comp.guess(y, **indep_var_kwargs))
                        except NotImplementedError:
                            # Components without a guess() keep make_params values
                            pass
                    # Given parameters must override guessed parameters
                    initial_params = guessed_params.update(initial_params)

                else:
                    try:
                        initial_params = model.guess(y, **indep_var_kwargs).update(
                            initial_params
                        )
                    except NotImplementedError:
                        warnings.warn(
                            f"`guess` is not implemented for {model}, "
                            "using supplied initial parameters",
                            stacklevel=1,
                        )
                        initial_params = model.make_params().update(initial_params)
            try:
                fitresult: lmfit.model.ModelResult = model.fit(
                    y, **indep_var_kwargs, params=initial_params, **kwargs
                )
            except RuntimeError:
                if errors == "raise":
                    raise
                popt, perr, stats = _nan_result()
            else:
                if fitresult.success:
                    popt, perr = [], []
                    for name in param_names:
                        p = fitresult.params[name]
                        popt.append(p.value if p.value is not None else np.nan)
                        perr.append(p.stderr if p.stderr is not None else np.nan)
                    popt, perr = np.array(popt), np.array(perr)

                    stats = [getattr(fitresult, s) for s in stat_names]
                    stats = np.array([s if s is not None else np.nan for s in stats])
                else:
                    popt, perr, stats = _nan_result()

            return popt, perr, stats

        result = type(self._obj)()
        for name, da in self._obj.data_vars.items():
            # Datasets created by DataArray.to_dataset use the sentinel name;
            # strip it so output variables are named "modelfit_*" with no prefix.
            if name is _THIS_ARRAY:
                name = ""
            else:
                name = f"{name!s}_"

            # One core-dim list per input: the data variable plus each
            # coordinate share the reduced dims; the serialized parameter
            # string is scalar.
            input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)]
            input_core_dims.append([])  # core dims for the parameters

            popt, perr, stats = xr.apply_ufunc(
                _wrapper,
                da,
                *coords_,
                params,
                vectorize=True,
                dask="parallelized",
                input_core_dims=input_core_dims,
                output_core_dims=[["param"], ["param"], ["fit_stat"]],
                dask_gufunc_kwargs={
                    "output_sizes": {
                        # Keys must match output_core_dims above; the stats
                        # dim is "fit_stat" (previously the mismatched key
                        # "stat" broke dask-backed fits).
                        "param": n_params,
                        "fit_stat": n_stats,
                    },
                },
                output_dtypes=(np.float64, np.float64, np.float64),
                exclude_dims=set(reduce_dims_),
                kwargs=kwargs,
            )
            result[name + "modelfit_coefficients"] = popt
            result[name + "modelfit_stderr"] = perr
            result[name + "modelfit_stats"] = stats

        result = result.assign_coords({"param": param_names, "fit_stat": stat_names})
        result.attrs = self._obj.attrs.copy()

        return result


@xr.register_dataarray_accessor("modelfit")
class ModelFitDataArrayAccessor(ERLabAccessor):
    """`xarray.DataArray.modelfit` accessor for fitting lmfit models."""

    def __call__(self, *args, **kwargs):
        # Promote the DataArray to a single-variable Dataset under the
        # sentinel name and delegate to the Dataset accessor, which strips
        # the sentinel so the outputs are named "modelfit_*".
        wrapped = self._obj.to_dataset(name=_THIS_ARRAY)
        return wrapped.modelfit(*args, **kwargs)

    # Reuse the Dataset docstring, adapted for DataArray: drop the "[var]_"
    # prefixes and point the See Also entries at the DataArray equivalents.
    # The three substitutions are independent, so their order is irrelevant.
    __call__.__doc__ = (
        ModelFitDatasetAccessor.__call__.__doc__
        .replace("[var]_", "")
        .replace("Dataset.curve_fit", "DataArray.curve_fit")
        .replace("Dataset.polyfit", "DataArray.polyfit")
    )
6 changes: 2 additions & 4 deletions src/erlab/accessors/kspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy.typing as npt
import xarray as xr

from erlab.accessors.utils import ERLabAccessor
from erlab.analysis.interpolate import interpn
from erlab.analysis.kspace import AxesConfiguration, get_kconv_func, kz_func
from erlab.constants import rel_kconv, rel_kzconv
Expand Down Expand Up @@ -175,7 +176,7 @@ def reset(self) -> "OffsetView":


@xr.register_dataarray_accessor("kspace")
class MomentumAccessor:
class MomentumAccessor(ERLabAccessor):
"""`xarray.DataArray.kspace` accessor for momentum conversion related utilities.
This class provides convenient access to various momentum-related properties of a
Expand All @@ -185,9 +186,6 @@ class MomentumAccessor:
"""

def __init__(self, xarray_obj: xr.DataArray):
self._obj = xarray_obj

@property
def configuration(self) -> AxesConfiguration:
"""Returns the experimental configuration.
Expand Down
4 changes: 4 additions & 0 deletions src/erlab/accessors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

T = TypeVar("T")

# Used as the key corresponding to a DataArray's variable when converting
# arbitrary DataArray objects to datasets, from xarray.core.dataarray
_THIS_ARRAY = "<this-array>"


class ERLabAccessor:
"""Base class for accessors."""
Expand Down

0 comments on commit e06982d

Please sign in to comment.