Skip to content

Commit

Permalink
fix(plotting): add some validation checks to plot_array
Browse files Browse the repository at this point in the history
The functions `plot_array` and `plot_array_2d` now checks if the input array coordinates are uniformly spaced. If they are not, a warning is issued and the user is informed that the plot may not be accurate.
  • Loading branch information
kmnhan committed Apr 29, 2024
1 parent 8f23f99 commit 2e0753c
Showing 1 changed file with 110 additions and 8 deletions.
118 changes: 110 additions & 8 deletions src/erlab/plotting/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import contextlib
import copy
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Literal, Union, cast

import matplotlib
import matplotlib.colorbar
import matplotlib.colors
import matplotlib.image
import matplotlib.patches
Expand All @@ -43,6 +45,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Sequence

from matplotlib.typing import ColorType

figure_width_ref = {
"aps": [3.4, 7.0],
"aip": [3.37, 6.69],
Expand Down Expand Up @@ -259,22 +263,40 @@ def place_inset(
return _ez_inset(parent_axes, width, height, pad, loc, **kwargs)


def array_extent(data: xr.DataArray) -> tuple[float, float, float, float]:
def array_extent(
darr: xr.DataArray, rtol: float = 1.0e-5, atol: float = 1.0e-8
) -> tuple[float, float, float, float]:
"""Get the extent of a :class:`xarray.DataArray`.
The extent can be used as the `extent` argument in :func:`matplotlib.pyplot.imshow`.
Parameters
----------
data
darr
A two-dimensional :class:`xarray.DataArray`.
rtol, atol
Tolerance for the coordinates to be considered evenly spaced. The default values
are consistent with `numpy.isclose`.
Returns
-------
x0, x1, y0, y1 : float
"""
data_coords = tuple(data[dim].values for dim in data.dims)
if darr.ndim != 2:
raise ValueError("Input array must be 2D")

data_coords = tuple(darr[dim].values for dim in darr.dims)
for dim, coord in zip(darr.dims, data_coords, strict=True):
dif = np.diff(coord)
if not np.allclose(dif, dif[0], rtol=rtol, atol=atol):
warnings.warn(
f"Coordinates for {dim} are not evenly spaced, and the plot may not be "
"accurate. Use `DataArray.plot`, `xarray.plot.pcolormesh` or "
"`matplotlib.pyplot.pcolormesh` for non-evenly spaced data.",
stacklevel=2,
)

data_incs = tuple(coord[1] - coord[0] for coord in data_coords)
data_lims = tuple((coord[0], coord[-1]) for coord in data_coords)
y0, x0 = data_lims[0][0] - 0.5 * data_incs[0], data_lims[1][0] - 0.5 * data_incs[1]
Expand All @@ -299,14 +321,16 @@ def plot_array(
rad2deg: bool | Iterable[str] = False,
func: Callable | None = None,
func_args: dict | None = None,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
**improps,
) -> matplotlib.image.AxesImage:
"""Plot a 2D :class:`xarray.DataArray` using :func:`matplotlib.pyplot.imshow`.
Parameters
----------
arr
A two-dimensional :class:`xarray.DataArray` to be plotted.
A two-dimensional :class:`xarray.DataArray` with evenly spaced coordinates.
ax
The target :class:`matplotlib.axes.Axes`.
colorbar
Expand All @@ -326,6 +350,10 @@ def plot_array(
same shape as the input.
func_args
Keyword arguments passed onto `func`.
rtol, atol
By default, the input array is checked for evenly spaced coordinates. ``rtol``
and ``atol`` are the tolerances for the coordinates to be considered evenly
spaced. The default values are consistent with `numpy.isclose`.
**improps
Keyword arguments passed onto :func:`matplotlib.axes.Axes.imshow`.
Expand Down Expand Up @@ -377,7 +405,7 @@ def plot_array(

improps_default = {
"interpolation": "none",
"extent": array_extent(arr),
"extent": array_extent(arr, rtol, atol),
"aspect": "auto",
"origin": "lower",
"rasterized": True,
Expand Down Expand Up @@ -422,14 +450,84 @@ def plot_array_2d(
cmap: matplotlib.colors.Colormap | str | None = None,
lnorm: matplotlib.colors.Normalize | None = None,
cnorm: matplotlib.colors.Normalize | None = None,
background: Any = None,
background: ColorType | None = None,
colorbar: bool = True,
cax: matplotlib.axes.Axes | None = None,
colorbar_kw: dict | None = None,
imshow_kw: dict | None = None,
N: int = 256,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
**indexers_kwargs,
):
) -> tuple[matplotlib.image.AxesImage, matplotlib.colorbar.Colorbar | None]:
"""
Plot a 2D array with associated color array.
The lightness array represents the intensity values, while the color array
represents some other property. The arrays must have the same shape.
Parameters
----------
larr
The 2D array representing the lightness values.
carr
The 2D array representing the color values.
ax
The axes on which to plot the array. If None, the current axes will be used.
normalize_with_larr
Whether to normalize the color array with the lightness array. Default is False.
xlim
The x-axis limits for the plot. If a float, it represents the symmetric limits
around 0. If a tuple, it represents the lower and upper limits. If None, the
limits are determined from the data.
ylim
The y-axis limits for the plot. If a float, it represents the symmetric limits
around 0. If a tuple, it represents the lower and upper limits. If None, the
limits are determined from the data.
cmap
The colormap to use for the color array. If None, a linear segmented colormap
consisting of blue, black, and red is used.
lnorm
The normalization object for the lightness array.
cnorm
The normalization object for the color array.
background
The background color to use for the plot. If None, white is used.
colorbar
Whether to create a colorbar. Default is `True`.
cax
The axes on which to create the colorbar if `colorbar` is `True`. If None, a new
axes will be created for the colorbar.
colorbar_kw
Additional keyword arguments to pass to `matplotlib.pyplot.colorbar`.
imshow_kw
Additional keyword arguments to pass to `matplotlib.pyplot.imshow`.
N
The number of levels in the colormap. Default is 256.
rtol, atol
By default, the input array is checked for evenly spaced coordinates. ``rtol``
and ``atol`` are the tolerances for the coordinates to be considered evenly
spaced. The default values are consistent with `numpy.isclose`.
**indexers_kwargs : dict
Additional keyword arguments to pass to `qsel` to select the data to plot. Note
that the resulting data after the selection must be 2D.
Returns
-------
im : matplotlib.image.AxesImage
The plotted image.
cb : matplotlib.colorbar.Colorbar or None
The colorbar associated with the plot. If `colorbar` is False, None is returned.
Example
-------
>>> import erlab.plotting.erplot as eplt
>>> import matplotlib.pyplot as plt
>>> import xarray as xr
>>> larr = xr.DataArray([[1, 2, 3], [4, 5, 6]])
>>> carr = xr.DataArray([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
>>> eplt.plot_array_2d(larr, carr)
"""
if lnorm is None:
lnorm = plt.Normalize()
else:
Expand All @@ -455,6 +553,10 @@ def plot_array_2d(

larr = larr.qsel(**indexers_kwargs).copy(deep=True)
carr = carr.qsel(**indexers_kwargs).copy(deep=True)

if larr.shape != carr.shape:
raise ValueError("Lightness and color array must have the same shape")

sel_kw = {}

if xlim is not None:
Expand Down Expand Up @@ -515,7 +617,7 @@ def plot_array_2d(
aspect="auto",
)

im = ax.imshow(img, extent=array_extent(larr), **imshow_kw)
im = ax.imshow(img, extent=array_extent(larr, rtol, atol), **imshow_kw)
ax.set_xlabel(str(larr.dims[0]))
ax.set_ylabel(str(larr.dims[1]))
fancy_labels(ax)
Expand Down

0 comments on commit 2e0753c

Please sign in to comment.