From 59b35f53f3ef7f518aec92e05854dba42ddba56f Mon Sep 17 00:00:00 2001 From: Kimoon Han Date: Sat, 20 Apr 2024 12:44:14 +0900 Subject: [PATCH] feat: add more output and parallelization to fit accessor Allows a dictionary of `DataArray`s as the `params` argument to the fit accessor. Now, the returned `Dataset` contains the input data and the best fit array. Relevant tests have been added. --- src/erlab/accessors/__init__.py | 7 +- src/erlab/accessors/fit.py | 488 +++++++++++++++++++++++++++----- tests/accessors/test_fit.py | 185 ++++++++++++ 3 files changed, 608 insertions(+), 72 deletions(-) create mode 100644 tests/accessors/test_fit.py diff --git a/src/erlab/accessors/__init__.py b/src/erlab/accessors/__init__.py index 1777a559..f8912bcf 100644 --- a/src/erlab/accessors/__init__.py +++ b/src/erlab/accessors/__init__.py @@ -20,6 +20,7 @@ __all__ = [ "ModelFitDataArrayAccessor", "ModelFitDatasetAccessor", + "ParallelFitDataArrayAccessor", "MomentumAccessor", "OffsetView", "ImageToolAccessor", @@ -27,6 +28,10 @@ "SelectionAccessor", ] -from erlab.accessors.fit import ModelFitDataArrayAccessor, ModelFitDatasetAccessor +from erlab.accessors.fit import ( + ModelFitDataArrayAccessor, + ModelFitDatasetAccessor, + ParallelFitDataArrayAccessor, +) from erlab.accessors.kspace import MomentumAccessor, OffsetView from erlab.accessors.utils import ImageToolAccessor, PlotAccessor, SelectionAccessor diff --git a/src/erlab/accessors/fit.py b/src/erlab/accessors/fit.py index bbd32e14..eaa540d7 100644 --- a/src/erlab/accessors/fit.py +++ b/src/erlab/accessors/fit.py @@ -1,15 +1,113 @@ -__all__ = ["ModelFitDataArrayAccessor", "ModelFitDatasetAccessor"] +from __future__ import annotations +__all__ = [ + "ModelFitDataArrayAccessor", + "ModelFitDatasetAccessor", + "ParallelFitDataArrayAccessor", +] + +import copy +import itertools import warnings from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal +import joblib import lmfit import numpy as np +import tqdm.auto import xarray as xr -from xarray.core.types import Dims from erlab.accessors.utils import _THIS_ARRAY, ERLabAccessor +from erlab.parallel import joblib_progress + +if TYPE_CHECKING: + from xarray.core.types import Dims + + +def _nested_dict_vals(d): + for v in d.values(): + if isinstance(v, Mapping): + yield from _nested_dict_vals(v) + else: + yield v + + +def _broadcast_dict_values(d: Mapping[str, Any]) -> Mapping[str, xr.DataArray]: + to_broadcast = {} + for k, v in d.items(): + if isinstance(v, xr.DataArray | xr.Dataset): + to_broadcast[k] = v + else: + to_broadcast[k] = xr.DataArray(v) + + for k, v in zip(to_broadcast.keys(), xr.broadcast(*to_broadcast.values())): + d[k] = v + return d + + +def _concat_along_keys(d: Mapping[str, xr.DataArray], dim_name: str) -> xr.DataArray: + return xr.concat(d.values(), d.keys()).rename(concat_dim=dim_name) + + +def _parse_params(d: Mapping[str, Any], dask: bool) -> xr.DataArray | _ParametersWrapper: + if isinstance(d, lmfit.Parameters): + # Input to apply_ufunc cannot be a Mapping, so wrap in a class + return _ParametersWrapper(d) + + # Iterate over all values + for v in _nested_dict_vals(d): + if isinstance(v, xr.DataArray): + # For dask arrays, auto rechunking with object dtype is unsupported, so must + # convert to str + return _parse_multiple_params(copy.deepcopy(d), dask) + + return _ParametersWrapper(lmfit.create_params(**d)) + + +def _parse_multiple_params(d: Mapping[str, Any], as_str: bool) -> xr.DataArray: + for k in d.keys(): + if
isinstance(d[k], int | float | complex | xr.DataArray): + d[k] = {"value": d[k]} + + d[k] = _concat_along_keys(_broadcast_dict_values(d[k]), "__dict_keys") + + da = _concat_along_keys(_broadcast_dict_values(d), "__param_names") + + pnames = tuple(da["__param_names"].values) + argnames = tuple(da["__dict_keys"].values) + + def _reduce_to_param(arr, axis=0): + out_arr = np.empty_like(arr.mean(axis=axis), dtype=object) + for i in range(out_arr.size): + out_arr.flat[i] = {} + + for i, par in enumerate(pnames): + for j, name in enumerate(argnames): + for k, val in enumerate(arr[i, j].flat): + if par not in out_arr.flat[k]: + out_arr.flat[k][par] = {} + + if np.isfinite(val): + out_arr.flat[k][par][name] = val + + for i in range(out_arr.size): + out_arr.flat[i] = lmfit.create_params(**out_arr.flat[i]) + if as_str: + out_arr.flat[i] = out_arr.flat[i].dumps() + + if as_str: + return out_arr.astype(str) + else: + return out_arr + + da = da.reduce(_reduce_to_param, ("__dict_keys", "__param_names")) + return da + + +class _ParametersWrapper: + def __init__(self, params: lmfit.Parameters): + self.params = params @xr.register_dataset_accessor("modelfit") @@ -22,10 +120,10 @@ def __call__( model: lmfit.Model, reduce_dims: Dims = None, skipna: bool = True, - params: str - | lmfit.Parameters + params: lmfit.Parameters | Mapping[str, float | dict[str, Any]] | xr.DataArray + | xr.Dataset | None = None, guess: bool = False, errors: Literal["raise", "ignore"] = "raise", @@ -35,64 +133,105 @@ def __call__( output_result: bool = True, **kwargs, ) -> xr.Dataset: - """ - Curve fitting optimization for arbitrary functions. + """Curve fitting optimization for arbitrary models. - Wraps :func:`lmfit.Model.fit` with `apply_ufunc`. + Wraps :meth:`lmfit.Model.fit` with + :func:`xarray.apply_ufunc`. Parameters ---------- - coords : hashable, xarray.DataArray, or sequence of hashable or xarray.DataArray + coords : Hashable, xarray.DataArray, or Sequence of Hashable or xarray.DataArray Independent coordinate(s) over which to perform the curve fitting. Must share at least one dimension with the calling object. When fitting multi-dimensional functions, supply `coords` as a sequence in the same order as arguments in `func`. To fit along existing dimensions of the calling object, `coords` can also be specified as a str or sequence of strs. - model : lmfit.Model - A model object to fit to the data. The model must be an instance of - `lmfit.Model`. + model : `lmfit.Model` + A model object to fit to the data. The model must be an *instance* of + :class:`lmfit.Model`. reduce_dims : str, Iterable of Hashable or None, optional Additional dimension(s) over which to aggregate while fitting. For example, - calling `ds.curvefit(coords='time', reduce_dims=['lat', 'lon'], ...)` will + calling `ds.modelfit(coords='time', reduce_dims=['lat', 'lon'], ...)` will aggregate all lat and lon points and fit the specified function along the time dimension. skipna : bool, default: True Whether to skip missing values when fitting. Default is True. - params : str, lmfit.Parameters, dict-like, or xarray.DataArray, optional - Optional input parameters to the fit. If a string, it should be a JSON - string representation of the parameters, generated by - `lmfit.Parameters.dumps`. If a `lmfit.Parameters` object, it will be used as - is. If a dict-like object, it will be converted to a `lmfit.Parameters` - object. If the values are DataArrays, they will be appropriately broadcast - to the coordinates of the array.
If none or only some parameters are passed, - the rest will be assigned initial values and bounds with - :meth:`lmfit.Model.make_params`, or guessed with :meth:`lmfit.Model.guess` - if `guess` is `True`. + params : lmfit.Parameters, dict-like, or xarray.DataArray, optional + Optional input parameters to the fit. If a `lmfit.Parameters` object, it + will be used for all fits. If a + dict-like object, it must look like the keyword arguments to + :func:`lmfit.create_params`. Additionally, + each value of the dictionary may also be a DataArray, which will be + broadcast appropriately. If a DataArray, each entry must be a + dictionary-like object, a `lmfit.Parameters` + object, or a JSON string created with :meth:`lmfit.Parameters.dumps`. If given a Dataset, the names of the + data variables in the Dataset must match the names of the data variables in + the calling object, and each data variable will be used as the parameters + for the corresponding data variable. If none or only some parameters are + passed, the rest will be assigned initial values and bounds with + :meth:`lmfit.Model.make_params`, or guessed + with :meth:`lmfit.Model.guess` if `guess` is + `True`. guess : bool, default: `False` - Whether to guess the initial parameters with :meth:`lmfit.Model.guess`. For - composite models, the parameters will be guessed for each component. + Whether to guess the initial parameters with :meth:`lmfit.Model.guess`. For composite models, the parameters will be + guessed for each component. errors : {"raise", "ignore"}, default: "raise" - If `'raise'`, any errors from the :func:`lmfit.Model.fit` optimization will - raise an exception. If `'ignore'`, the return values for the coordinates - where the fitting failed will be NaN. + If `'raise'`, any errors from the :meth:`lmfit.Model.fit` optimization will raise an exception. If + `'ignore'`, the return values for the coordinates where the fitting failed + will be NaN. + parallel : bool, optional + Whether to parallelize the fits over the data variables. If not provided, + parallelization is only applied for non-dask Datasets with more than 200 + data variables. + parallel_kw : dict, optional + Additional keyword arguments to pass to the parallelization backend + :class:`joblib.Parallel` if `parallel` is `True`. + progress : bool, default: `False` + Whether to show a progress bar for fitting over data variables. Only useful + if there are multiple data variables to fit. + output_result : bool, default: `True` + Whether to include the full :class:`lmfit.model.ModelResult` object in the + output dataset. If `True`, the result will be stored in a variable named + `[var]_modelfit_results`. **kwargs : optional - Additional keyword arguments to passed to :func:`lmfit.Model.fit`. + Additional keyword arguments passed to :meth:`lmfit.Model.fit`. Returns ------- curvefit_results : xarray.Dataset A single dataset which contains: + [var]_modelfit_results + The full :class:`lmfit.model.ModelResult` object from the fit. Only + included if `output_result` is `True`. [var]_modelfit_coefficients The coefficients of the best fit. [var]_modelfit_stderr The standard errors of the coefficients. + [var]_modelfit_covariance + The covariance matrix of the coefficients. Note that elements + corresponding to non-varying parameters are set to NaN, and the actual + covariance matrix may be smaller than the returned array. [var]_modelfit_stats - Statistics from the fit. See :func:`lmfit.minimize`. + Statistics from the fit. See :func:`lmfit.minimize`.
+ [var]_modelfit_data + Data used for the fit. + [var]_modelfit_best_fit + The model evaluated at the best fit parameters. See Also -------- - xarray.Dataset.curve_fit xarray.Dataset.polyfit lmfit.model.Model.fit + xarray.Dataset.curvefit + + xarray.Dataset.polyfit + + lmfit.model.Model.fit + scipy.optimize.curve_fit """ @@ -101,15 +240,20 @@ def __call__( if params is None: params = lmfit.create_params() + if parallel_kw is None: + parallel_kw = {} + if kwargs is None: kwargs = {} - # Input to apply_ufunc cannot be a Mapping, so convert parameters to str - if isinstance(params, lmfit.Parameters): - params: str = params.dumps() - elif isinstance(params, Mapping): - # Given as a mapping from str to float or dict - params: str = lmfit.create_params(**params).dumps() + is_dask: bool = not ( + self._obj.chunksizes is None or len(self._obj.chunksizes) == 0 + ) + + if not isinstance(params, xr.Dataset) and isinstance(params, Mapping): + # Given as a mapping from str to ... + # float or DataArray or dict of str to Any (including DataArray of Any) + params = _parse_params(params, is_dask) reduce_dims_: list[Hashable] if not reduce_dims: @@ -142,11 +286,11 @@ def __call__( ) # Check that initial guess and bounds only contain coords in preserved_dims - if isinstance(params, xr.DataArray): + if isinstance(params, xr.DataArray | xr.Dataset): unexpected = set(params.dims) - set(preserved_dims) if unexpected: raise ValueError( - f"Initial guess has unexpected dimensions {tuple(unexpected)}. It " + f"Parameters object has unexpected dimensions {tuple(unexpected)}. It " "should only have dimensions that are in data dimensions " f"{preserved_dims}." ) @@ -190,10 +334,16 @@ def _wrapper(Y, *args, **kwargs): initial_params = lmfit.create_params() else: initial_params = model.make_params() - if isinstance(init_params_, str): + + if isinstance(init_params_, _ParametersWrapper): + initial_params.update(init_params_.params) + + elif isinstance(init_params_, str): initial_params.update(lmfit.Parameters().loads(init_params_)) + elif isinstance(init_params_, lmfit.Parameters): initial_params.update(init_params_) + elif isinstance(init_params_, Mapping): for p, v in init_params_.items(): if isinstance(v, Mapping): @@ -201,17 +351,27 @@ def _wrapper(Y, *args, **kwargs): else: initial_params[p].set(value=v) + popt = np.full([n_params], np.nan) + perr = np.full([n_params], np.nan) + pcov = np.full([n_params, n_params], np.nan) + stats = np.full([n_stats], np.nan) + data = Y.copy() + best = np.full_like(data, np.nan) + x = np.vstack([c.ravel() for c in coords__]) y = Y.ravel() + if skipna: mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0) x = x[:, mask] y = y[mask] if not len(y): - popt = np.full([n_params], np.nan) - perr = np.full([n_params, n_params], np.nan) - stats = np.full([n_stats], np.nan) - return popt, perr, stats + modres = lmfit.model.ModelResult(model, model.make_params(), data=y) + modres.success = False + return popt, perr, pcov, stats, data, best, modres + else: + mask = np.full_like(y, True, dtype=bool) + x = np.squeeze(x) if n_coords == 1: @@ -243,67 +403,169 @@ def _wrapper(Y, *args, **kwargs): ) initial_params = model.make_params().update(initial_params) try: - fitresult: lmfit.model.ModelResult = model.fit( + modres: lmfit.model.ModelResult = model.fit( y, **indep_var_kwargs, params=initial_params, **kwargs ) - except RuntimeError: + except ValueError: if errors == "raise": raise - popt = np.full([n_params], np.nan) - perr = np.full([n_params, n_params], np.nan) - stats = np.full([n_stats], np.nan) + modres =
lmfit.model.ModelResult(model, initial_params, data=y) + modres.success = False + return popt, perr, pcov, stats, data, best, modres else: - if fitresult.success: + if modres.success: popt, perr = [], [] for name in param_names: - p = fitresult.params[name] + p = modres.params[name] popt.append(p.value if p.value is not None else np.nan) perr.append(p.stderr if p.stderr is not None else np.nan) + popt, perr = np.array(popt), np.array(perr) - stats = [getattr(fitresult, s) for s in stat_names] + stats = [getattr(modres, s) for s in stat_names] stats = np.array([s if s is not None else np.nan for s in stats]) - else: - popt = np.full([n_params], np.nan) - perr = np.full([n_params, n_params], np.nan) - stats = np.full([n_stats], np.nan) - return popt, perr, stats + if modres.covar is not None: + var_names = modres.var_names + for vi in range(modres.nvarys): + i = param_names.index(var_names[vi]) + for vj in range(modres.nvarys): + j = param_names.index(var_names[vj]) + pcov[i, j] = modres.covar[vi, vj] - result = type(self._obj)() - for name, da in self._obj.data_vars.items(): + best.flat[mask] = modres.best_fit + + return popt, perr, pcov, stats, data, best, modres + + def _output_wrapper(name, da, out=None) -> dict: if name is _THIS_ARRAY: name = "" else: name = f"{name!s}_" + if out is None: + out = {} + input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)] input_core_dims.extend([[] for _ in range(1)]) # core_dims for parameters - popt, perr, stats = xr.apply_ufunc( + if isinstance(params, xr.Dataset): + try: + params_to_apply = params[name.rstrip("_")] + except KeyError: + params_to_apply = params[float(name.rstrip("_"))] + else: + params_to_apply = params + + popt, perr, pcov, stats, data, best, modres = xr.apply_ufunc( _wrapper, da, *coords_, - params, + params_to_apply, vectorize=True, dask="parallelized", input_core_dims=input_core_dims, - output_core_dims=[["param"], ["param"], ["fit_stat"]], + output_core_dims=[ + ["param"], + ["param"], + ["cov_i", "cov_j"], + ["fit_stat"], + reduce_dims_, + reduce_dims_, + [], + ], dask_gufunc_kwargs={ "output_sizes": { "param": n_params, - "stat": n_stats, - }, + "fit_stat": n_stats, + "cov_i": n_params, + "cov_j": n_params, + } + | {dim: self._obj.coords[dim].size for dim in reduce_dims_} }, - output_dtypes=(np.float64, np.float64, np.float64), + output_dtypes=( + np.float64, + np.float64, + np.float64, + np.float64, + np.float64, + np.float64, + lmfit.model.ModelResult, + ), exclude_dims=set(reduce_dims_), kwargs=kwargs, ) - result[name + "modelfit_coefficients"] = popt - result[name + "modelfit_stderr"] = perr - result[name + "modelfit_stats"] = stats - result = result.assign_coords({"param": param_names, "fit_stat": stat_names}) + if output_result: + out[name + "modelfit_results"] = modres + + out[name + "modelfit_coefficients"] = popt + out[name + "modelfit_stderr"] = perr + out[name + "modelfit_covariance"] = pcov + out[name + "modelfit_stats"] = stats + out[name + "modelfit_data"] = data + out[name + "modelfit_best_fit"] = best + + return out + + if parallel is None: + parallel = (not is_dask) and (len(self._obj.data_vars) > 200) + + tqdm_kw = { + "desc": "Fitting", + "total": len(self._obj.data_vars), + "disable": not progress, + } + + if parallel: + if is_dask: + warnings.warn( + "The input Dataset is chunked. 
Parallel fitting will not offer any " + "performance benefits.", + stacklevel=1, + ) + + parallel_kw.setdefault("n_jobs", -1) + parallel_kw.setdefault("max_nbytes", None) + parallel_kw.setdefault("return_as", "generator_unordered") + parallel_kw.setdefault("pre_dispatch", "n_jobs") + parallel_kw.setdefault("prefer", "processes") + + parallel_obj = joblib.Parallel(**parallel_kw) + + if parallel_obj.return_generator: + out_dicts = tqdm.auto.tqdm( + parallel_obj( + joblib.delayed(_output_wrapper)(name, da) + for name, da in self._obj.data_vars.items() + ), + **tqdm_kw, + ) + else: + with joblib_progress(**tqdm_kw) as _: + out_dicts = parallel_obj( + joblib.delayed(_output_wrapper)(name, da) + for name, da in self._obj.data_vars.items() + ) + result = type(self._obj)( + dict(itertools.chain.from_iterable(d.items() for d in out_dicts)) + ) + del out_dicts + + else: + result = type(self._obj)() + for name, da in tqdm.auto.tqdm(self._obj.data_vars.items(), **tqdm_kw): + _output_wrapper(name, da, result) + + result = result.assign_coords( + { + "param": param_names, + "fit_stat": stat_names, + "cov_i": param_names, + "cov_j": param_names, + } + | {dim: self._obj.coords[dim] for dim in reduce_dims_} ) result.attrs = self._obj.attrs.copy() return result @@ -313,13 +575,97 @@ def _wrapper(Y, *args, **kwargs): class ModelFitDataArrayAccessor(ERLabAccessor): """`xarray.DataArray.modelfit` accessor for fitting lmfit models.""" - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> xr.Dataset: return self._obj.to_dataset(name=_THIS_ARRAY).modelfit(*args, **kwargs) __call__.__doc__ = ( ModelFitDatasetAccessor.__call__.__doc__.replace( - "Dataset.curve_fit", "DataArray.curve_fit" + "Dataset.curvefit", "DataArray.curvefit" ) .replace("Dataset.polyfit", "DataArray.polyfit") .replace("[var]_", "") ) + + +@xr.register_dataarray_accessor("parallel_fit") +class ParallelFitDataArrayAccessor(ERLabAccessor): + """ + `xarray.DataArray.parallel_fit` accessor for fitting lmfit models in parallel along + a single dimension. + + """ + + _VAR_KEYS: tuple[str, ...] = ( + "modelfit_results", + "modelfit_coefficients", + "modelfit_stderr", + "modelfit_covariance", + "modelfit_stats", + "modelfit_data", + "modelfit_best_fit", + ) + + def __call__(self, dim: str, model: lmfit.Model, **kwargs) -> xr.Dataset: + """ + Fit the specified model to the data along the given dimension. + + Parameters + ---------- + dim : str + The name of the dimension along which to fit the model. + model : lmfit.Model + The model to fit. + **kwargs : dict + Additional keyword arguments to be passed to + :func:`xarray.Dataset.modelfit`. + + Returns + ------- + curvefit_results : xarray.Dataset + The dataset containing the results of the fit. See + :func:`xarray.DataArray.modelfit` for details. + + """ + if self._obj.chunks is not None: + raise ValueError( + "The input DataArray is chunked. Parallel fitting will not offer any " + "performance benefits. 
Use `modelfit` instead" + ) + + ds = self._obj.to_dataset(dim, promote_attrs=True) + + kwargs.setdefault("parallel", True) + kwargs.setdefault("progress", True) + + if isinstance(kwargs.get("params", None), Mapping): + kwargs["params"] = _parse_params(kwargs["params"], dask=False) + + if isinstance(kwargs.get("params", None), xr.DataArray): + kwargs["params"] = kwargs["params"].to_dataset(dim, promote_attrs=True) + + fitres = ds.modelfit(set(self._obj.dims) - {dim}, model, **kwargs) + + drop_keys = [] + concat_vars = {} + for k in ds.data_vars.keys(): + for var in self._VAR_KEYS: + key = f"{k}_{var}" + if key in fitres: + if var not in concat_vars: + concat_vars[var] = [] + concat_vars[var].append(fitres[key]) + drop_keys.append(key) + + return ( + fitres.drop_vars(drop_keys) + .assign( + { + k: xr.concat( + v, dim, coords="minimal", compat="override", join="override" + ) + for k, v in concat_vars.items() + } + ) + .assign_coords({dim: self._obj[dim]}) + ) diff --git a/tests/accessors/test_fit.py b/tests/accessors/test_fit.py new file mode 100644 index 00000000..6774a63c --- /dev/null +++ b/tests/accessors/test_fit.py @@ -0,0 +1,185 @@ +import erlab.accessors # noqa: F401 +import lmfit +import numpy as np +import pytest +import xarray as xr + + +@pytest.mark.parametrize("use_dask", [True, False]) +def test_modelfit(use_dask: bool): + # Tests are adapted from xarray's curvefit tests + + def exp_decay(t, n0, tau=1): + return n0 * np.exp(-t / tau) + + def power(t, a): + return np.power(t, a) + + t = np.arange(0, 5, 0.5) + da = xr.DataArray( + np.stack([exp_decay(t, 3, 3), exp_decay(t, 5, 4), np.nan * t], axis=-1), + dims=("t", "x"), + coords={"t": t, "x": [0, 1, 2]}, + ) + da[0, 0] = np.nan + + expected = xr.DataArray( + [[3, 3], [5, 4], [np.nan, np.nan]], + dims=("x", "param"), + coords={"x": [0, 1, 2], "param": ["n0", "tau"]}, + ) + + if use_dask: + da = da.chunk({"x": 1}) + + # Create model + model = lmfit.Model(exp_decay) + + # Params as dictionary + fit = da.modelfit( + coords=[da.t], + model=model, + params={"n0": 4, "tau": {"min": 2, "max": 6}}, + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected, rtol=1e-3) + + # Params as lmfit.Parameters + fit = da.modelfit( + coords=[da.t], + model=model, + params=lmfit.create_params(n0=4, tau={"min": 2, "max": 6}), + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected, rtol=1e-3) + + # Test parallel fits + if not use_dask: + fit = da.parallel_fit( + dim="x", + model=model, + params={"n0": 4, "tau": {"min": 2, "max": 6}}, + output_result=False, + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected, rtol=1e-3) + + fit = da.parallel_fit( + dim="x", + model=model, + params=lmfit.create_params(n0=4, tau={"min": 2, "max": 6}), + output_result=False, + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected, rtol=1e-3) + + if use_dask: + da = da.compute() + + # Test 0dim output + fit = da.modelfit( + coords="t", + model=lmfit.Model(power), + reduce_dims="x", + params={"a": {"value": 0.3, "vary": True}}, + ) + + assert "a" in fit.param + assert fit.modelfit_results.dims == () + + +@pytest.mark.parametrize("use_dask", [True, False]) +def test_modelfit_params(use_dask: bool): + def sine(t, a, f, p): + return a * np.sin(2 * np.pi * (f * t + p)) + + t = np.arange(0, 2, 0.02) + da = xr.DataArray( + np.stack([sine(t, 1.0, 2, 0), sine(t, 1.0, 2, 0)]), + coords={"x": [0, 1], "t": t}, + ) + + expected = xr.DataArray( + [[1, 2, 0], [-1, 2, 0.5]], + coords={"x": [0, 1], "param": ["a", "f", "p"]}, + ) + + # 
Different initial guesses for different values of x + a_guess = [1.0, -1.0] + p_guess = [0.0, 0.5] + + if use_dask: + da = da.chunk({"x": 1}) + + # params as DataArray of JSON strings + params = [] + for a, p, f in zip(a_guess, p_guess, np.full_like(da.x, 2, dtype=float)): + params.append(lmfit.create_params(a=a, p=p, f=f).dumps()) + params = xr.DataArray(params, coords=[da.x]) + fit = da.modelfit( + coords=[da.t], + model=lmfit.Model(sine), + params=params, + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected) + + # params as mixed dictionary + fit = da.modelfit( + coords=[da.t], + model=lmfit.Model(sine), + params={ + "a": xr.DataArray(a_guess, coords=[da.x]), + "p": xr.DataArray(p_guess, coords=[da.x]), + "f": 2.0, + }, + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected) + + def sine(t, a, f, p): + return a * np.sin(2 * np.pi * (f * t + p)) + + t = np.arange(0, 2, 0.02) + da = xr.DataArray( + np.stack([sine(t, 1.0, 2, 0), sine(t, 1.0, 2, 0)]), + coords={"x": [0, 1], "t": t}, + ) + + # Fit a sine with different bounds: positive amplitude should result in a fit with + # phase 0 and negative amplitude should result in phase 0.5 * 2pi. + + expected = xr.DataArray( + [[1, 2, 0], [-1, 2, 0.5]], + coords={"x": [0, 1], "param": ["a", "f", "p"]}, + ) + + if use_dask: + da = da.chunk({"x": 1}) + + # params as DataArray of JSON strings + fit = da.modelfit( + coords=[da.t], + model=lmfit.Model(sine), + params=xr.DataArray( + [ + lmfit.create_params(**param_dict).dumps() + for param_dict in ( + {"f": 2, "p": 0.25, "a": {"value": 1, "min": 0, "max": 2}}, + {"f": 2, "p": 0.25, "a": {"value": -1, "min": -2, "max": 0}}, + ) + ], + coords=[da.x], + ), + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected, atol=1e-8) + + # params as mixed dictionary + fit = da.modelfit( + coords=[da.t], + model=lmfit.Model(sine), + params={ + "f": {"value": 2}, + "p": 0.25, + "a": { + "value": xr.DataArray([1, -1], coords=[da.x]), + "min": xr.DataArray([0, -2], coords=[da.x]), + "max": xr.DataArray([2, 0], coords=[da.x]), + }, + }, + ) + np.testing.assert_allclose(fit.modelfit_coefficients, expected, atol=1e-8)
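Usage sketch (not part of the diff above): a minimal example of the two accessors this patch adds. The Gaussian model, the synthetic data, and the `temp` coordinate are hypothetical, invented for illustration; the call signatures mirror the tests above.

import lmfit
import numpy as np
import xarray as xr

import erlab.accessors  # noqa: F401  # registers .modelfit and .parallel_fit


def gaussian(x, amp, cen, sig):
    return amp * np.exp(-((x - cen) ** 2) / (2 * sig**2))


x = np.linspace(-5, 5, 201)
da = xr.DataArray(
    np.stack([gaussian(x, 1.0 + 0.1 * i, 0.0, 1.0) for i in range(3)]),
    coords={"temp": [10.0, 20.0, 30.0], "x": x},
)

# Dictionary params may mix scalars, per-coordinate DataArrays, and nested
# dictionaries of bounds, as documented in the params docstring.
result = da.modelfit(
    coords="x",
    model=lmfit.Model(gaussian),
    params={
        "amp": xr.DataArray([1.0, 1.1, 1.2], coords=[da.temp]),
        "cen": 0.0,
        "sig": {"value": 1.0, "min": 0.0},
    },
)
print(result.modelfit_coefficients.sel(param="amp"))

# One joblib-dispatched fit per value of "temp":
result_par = da.parallel_fit(
    dim="temp",
    model=lmfit.Model(gaussian),
    params={"amp": 1.0, "cen": 0.0, "sig": 1.0},
)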
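A short sketch of consuming the enlarged output, continuing from the example above. Variable names follow the `[var]_modelfit_*` scheme documented in the docstring, with the `[var]_` prefix dropped for DataArray input.

best = result.modelfit_best_fit   # model evaluated at the best-fit parameters
data = result.modelfit_data       # input data as seen by the fit
cov = result.modelfit_covariance  # covariance matrix over ("cov_i", "cov_j")

# Standard errors correspond to the square root of the covariance diagonal;
# entries for non-varying parameters are NaN by construction.
sigma = np.sqrt(np.diagonal(cov.sel(temp=10.0).values))
print(sigma, result.modelfit_stderr.sel(temp=10.0).values)

residuals = data - best  # quick per-point fit-quality check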
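The parameter-broadcasting helpers can also be exercised directly to see what the fit wrapper receives; `_parse_params` is private API, so this is illustrative only and may change.

from erlab.accessors.fit import _parse_params

parsed = _parse_params(
    {"amp": xr.DataArray([1.0, 1.1, 1.2], coords=[da.temp]), "cen": 0.0},
    dask=False,
)
# With dask=False this yields an object-dtype DataArray over "temp" holding one
# lmfit.Parameters instance per point; with dask=True the entries are instead
# JSON strings (lmfit.Parameters.dumps), since dask cannot auto-rechunk
# object-dtype arrays.
print(parsed)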