Skip to content

Commit

Permalink
Merge pull request #33 from kmnhan/dev-2.5
Browse files Browse the repository at this point in the history
2.5 Update
  • Loading branch information
kmnhan authored May 13, 2024
2 parents f968c37 + f6f19ab commit dc9a9a8
Show file tree
Hide file tree
Showing 15 changed files with 831 additions and 101 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
ci:
autoupdate_commit_msg: "ci(pre-commit): pre-commit autoupdate"
autofix_commit_msg: "style: pre-commit auto fixes [...]"

repos:

# Meta hooks
Expand Down
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def linkcode_resolve(domain, info):
"cmasher": ("https://cmasher.readthedocs.io/", None),
"ipywidgets": ("https://ipywidgets.readthedocs.io/en/stable/", None),
"joblib": ("https://joblib.readthedocs.io/en/latest/", None),
"panel": ("https://panel.holoviz.org/", None),
"hvplot": ("https://hvplot.holoviz.org/", None),
}


Expand Down
39 changes: 37 additions & 2 deletions docs/source/user-guide/curve-fitting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@
"metadata": {},
"outputs": [],
"source": [
"darr.modelfit(\n",
"result_ds = darr.modelfit(\n",
" coords=\"alpha\",\n",
" model=GaussianModel() + LinearModel(),\n",
" params={\n",
Expand All @@ -795,7 +795,42 @@
" },\n",
" \"slope\": -0.1,\n",
" },\n",
")"
")\n",
"result_ds"
]
},
{
"cell_type": "raw",
"metadata": {
"editable": true,
"raw_mimetype": "text/restructuredtext",
"slideshow": {
"slide_type": ""
},
"tags": [],
"vscode": {
"languageId": "raw"
}
},
"source": [
"Visualizing fits\n",
"~~~~~~~~~~~~~~~~\n",
"\n",
"If `hvplot <https://github.com/holoviz/hvplot>`_ is installed, we can visualize the fit\n",
"results interactively with the `qshow` accessor.\n",
"\n",
".. note::\n",
"\n",
" If you are viewing this documentation online, the plot will not be interactive. Run the code locally to try it out."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result_ds.qshow(plot_components=True)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/source/user-guide/plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@
}
},
"source": [
":func:`plot_array <erlab.plotting.general.plot_array>` can also be accessed (for 2D data) through an :class:`accessor <erlab.accessors.utils.PlotAccessor>`."
":func:`plot_array <erlab.plotting.general.plot_array>` can also be accessed (for 2D data) through the :class:`qplot <erlab.accessors.utils.PlotAccessor>` accessor."
]
},
{
Expand Down Expand Up @@ -316,7 +316,7 @@
"source": [
"Here, we plotted each constant energy surface with :func:`plot_array\n",
"<erlab.plotting.general.plot_array>`. To remove the duplicated y axis labels and add\n",
"some annotations, we can use :func:`clean_labels <erlab.plotting.erplot.clean_labels>`\n",
"some annotations, we can use :func:`clean_labels <erlab.plotting.general.clean_labels>`\n",
"and :func:`label_subplot_properties\n",
"<erlab.plotting.annotations.label_subplot_properties>`:"
]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ exclude = [
module = [
"astropy.*",
"h5netcdf.*",
"hvplot.*",
"igor2.*",
"iminuit.*",
"ipywidgets.*",
Expand Down
10 changes: 8 additions & 2 deletions src/erlab/accessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
""" # noqa: D205

__all__ = [
"ImageToolAccessor",
"InteractiveDataArrayAccessor",
"InteractiveDatasetAccessor",
"ModelFitDataArrayAccessor",
"ModelFitDatasetAccessor",
"MomentumAccessor",
Expand All @@ -34,4 +35,9 @@
ParallelFitDataArrayAccessor,
)
from erlab.accessors.kspace import MomentumAccessor, OffsetView
from erlab.accessors.utils import ImageToolAccessor, PlotAccessor, SelectionAccessor
from erlab.accessors.utils import (
InteractiveDataArrayAccessor,
InteractiveDatasetAccessor,
PlotAccessor,
SelectionAccessor,
)
242 changes: 237 additions & 5 deletions src/erlab/accessors/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
__all__ = [
"ImageToolAccessor",
"InteractiveDataArrayAccessor",
"InteractiveDatasetAccessor",
"PlotAccessor",
"SelectionAccessor",
]

import importlib
import warnings
from collections.abc import Hashable, Mapping
from typing import Any, TypeGuard, TypeVar, cast

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

import erlab.plotting.erplot as eplt
Expand Down Expand Up @@ -90,16 +93,245 @@ def __call__(self, *args, **kwargs):


@xr.register_dataarray_accessor("qshow")
class ImageToolAccessor(ERLabDataArrayAccessor):
class InteractiveDataArrayAccessor(ERLabDataArrayAccessor):
"""`xarray.DataArray.qshow` accessor for interactive visualization."""

def __call__(self, *args, **kwargs):
"""Visualize the data interactively.
Chooses the appropriate interactive visualization method based on the number of
dimensions in the data.
Parameters
----------
*args
Positional arguments passed onto the interactive visualization function.
**kwargs
Keyword arguments passed onto the interactive visualization function.
"""
if self._obj.ndim >= 2 and self._obj.ndim <= 4:
return self.itool(*args, **kwargs)
else:
if importlib.util.find_spec("hvplot"):
return self._obj.hvplot(*args, **kwargs)

raise ValueError("Data must have at least two dimensions.")

def itool(self, *args, **kwargs):
"""Shortcut for :func:`itool <erlab.interactive.imagetool.itool>`.
Parameters
----------
*args
Positional arguments passed onto :func:`itool <erlab.interactive.imagetool.itool>`.
**kwargs
Keyword arguments passed onto :func:`itool <erlab.interactive.imagetool.itool>`.
"""
from erlab.interactive.imagetool import itool

return itool(self._obj, *args, **kwargs)

def hvplot(self, *args, **kwargs):
""":mod:`hvplot`-based interactive visualization.
Parameters
----------
*args
Positional arguments passed onto :meth:`DataArray.hvplot`.
**kwargs
Keyword arguments passed onto :meth:`DataArray.hvplot`.
Raises
------
ImportError
If :mod:`hvplot` is not installed.
"""
if not importlib.util.find_spec("hvplot"):
raise ImportError("hvplot is required for qshow.hvplot()")
import hvplot.xarray # noqa: F401

return self._obj.hvplot(*args, **kwargs)


@xr.register_dataset_accessor("qshow")
class InteractiveDatasetAccessor(ERLabDatasetAccessor):
"""`xarray.Dataset.qshow` accessor for interactive visualization."""

def __call__(self, *args, **kwargs):
"""Visualize the data interactively.
Chooses the appropriate interactive visualization method based on the data
variables.
Parameters
----------
*args
Positional arguments passed onto the interactive visualization function.
**kwargs
Keyword arguments passed onto the interactive visualization function.
"""
if self._is_fitresult:
return self.fit(*args, **kwargs)
else:
return self.itool(*args, **kwargs)

@property
def _is_fitresult(self) -> bool:
from erlab.accessors.fit import ParallelFitDataArrayAccessor

for var in set(ParallelFitDataArrayAccessor._VAR_KEYS) - {"modelfit_results"}:
if var not in self._obj.data_vars:
return False
return True

def itool(self, *args, **kwargs):
from erlab.interactive.imagetool import itool

if len(self._obj.dims) >= 2:
return itool(self._obj, *args, **kwargs)
return itool(self._obj, *args, **kwargs)

def hvplot(self, *args, **kwargs):
if not importlib.util.find_spec("hvplot"):
raise ImportError("hvplot is required for qshow.hvplot()")
import hvplot.xarray # noqa: F401

return self._obj.hvplot(*args, **kwargs)

itool.__doc__ = InteractiveDataArrayAccessor.itool.__doc__
hvplot.__doc__ = str(InteractiveDataArrayAccessor.hvplot.__doc__).replace(
"DataArray", "Dataset"
)

def fit(self, plot_components: bool = False):
"""Interactive visualization of fit results.
Parameters
----------
plot_components
If `True`, plot the components of the fit. Default is `False`. Requires the
Dataset to have a `modelfit_results` variable.
Returns
-------
:class:`panel.Column`
A panel containing the interactive visualization.
"""
if not importlib.util.find_spec("hvplot"):
raise ImportError("hvplot is required for interactive fit visualization")

import hvplot.xarray
import panel
import panel.widgets

from erlab.accessors.fit import ParallelFitDataArrayAccessor

for var in set(ParallelFitDataArrayAccessor._VAR_KEYS) - {"modelfit_results"}:
if var not in self._obj.data_vars:
raise ValueError("Dataset is not a fit result")

coord_dims = [
d
for d in self._obj.modelfit_stats.dims
if d in self._obj.modelfit_data.dims
]
other_dims = [d for d in self._obj.modelfit_data.dims if d not in coord_dims]

if len(other_dims) != 1:
raise ValueError("Only 1D fits are supported")

sliders = [
panel.widgets.DiscreteSlider(name=d, options=list(np.array(self._obj[d])))
for d in coord_dims
]

def get_slice(*s):
return self._obj.sel(dict(zip(coord_dims, s, strict=False)))

def get_slice_params(*s):
res_part = get_slice(*s).rename(param="Parameter")
return xr.merge(
[
res_part.modelfit_coefficients.rename("Value"),
res_part.modelfit_stderr.rename("Stderr"),
]
)

def get_comps(*s):
partial_res = get_slice(*s)
return xr.merge(
[
xr.DataArray(
v, dims=other_dims, coords=[self._obj[other_dims[0]]]
).rename(k)
for k, v in partial_res.modelfit_results.item()
.eval_components()
.items()
]
+ [
partial_res.modelfit_data,
partial_res.modelfit_best_fit,
]
)

part = hvplot.bind(get_slice, *sliders).interactive()
part_params = hvplot.bind(get_slice_params, *sliders).interactive()

if "modelfit_results" not in self._obj.data_vars:
warnings.warn(
"`model_results` not included in Dataset. "
"Components will not be plotted",
stacklevel=2,
)
plot_components = False

plot_kwargs = {
"responsive": True,
"min_width": 400,
"min_height": 500,
"title": "",
}
if plot_components:
part_comps = hvplot.bind(get_comps, *sliders).interactive()
data = part_comps.modelfit_data.hvplot.scatter(**plot_kwargs)
fit = part_comps.modelfit_best_fit.hvplot(c="k", ylabel="", **plot_kwargs)
components = part_comps.hvplot(
y=list(
self._obj.modelfit_results.values.flatten()[0]
.eval_components()
.keys()
),
legend="top_right",
group_label="Component",
**plot_kwargs,
)
plots = components * data * fit
else:
raise ValueError("Data must have at leasst two dimensions.")
data = part.modelfit_data.hvplot.scatter(**plot_kwargs)
fit = part.modelfit_best_fit.hvplot(c="k", ylabel="", **plot_kwargs)
plots = data * fit

return panel.Column(
plots,
panel.Spacer(height=30),
panel.Tabs(
(
"Parameters",
part_params.hvplot.table(
columns=["Parameter", "Value", "Stderr"],
title="",
responsive=True,
),
),
(
"Fit statistics",
part.modelfit_stats.hvplot.table(
columns=["fit_stat", "modelfit_stats"],
title="",
responsive=True,
),
),
),
)


@xr.register_dataarray_accessor("qsel")
Expand Down
Loading

0 comments on commit dc9a9a8

Please sign in to comment.