Skip to content

Commit

Permalink
feat(plotting.general): add NonUniformImage functionality to `plot_…
Browse files Browse the repository at this point in the history
…array`

`plot_array` can now plot data with unevenly spaced coordinates. It uses `matplotlib.image.NonUniformImage` with the default interpolation option set to 'nearest'. The resulting plot may be different from `xarray.DataArray.plot` which uses `pcolormesh` to generate image plots.
  • Loading branch information
kmnhan committed Aug 12, 2024
1 parent 29c37c4 commit 86d8c1a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 20 deletions.
105 changes: 92 additions & 13 deletions src/erlab/plotting/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
gen_2d_colormap,
nice_colorbar,
)
from erlab.utils.array import is_dims_uniform

if TYPE_CHECKING:
from collections.abc import Callable, Collection, Sequence
Expand Down Expand Up @@ -170,6 +171,56 @@ def array_extent(
return x0, x1, y0, y1


def _imshow_nonuniform(
ax,
x,
y,
A,
cmap=None,
norm=None,
*,
aspect=None,
interpolation=None,
alpha=None,
vmin=None,
vmax=None,
url=None,
**kwargs,
):
"""Display data as an image on a non-uniform grid.
:func:`matplotlib.pyplot.imshow` creates a :class:`matplotlib.image.AxesImage`, but
this function creates a :class:`matplotlib.image.NonUniformImage` instead.
"""
im = matplotlib.image.NonUniformImage(
ax,
cmap=cmap,
norm=norm,
interpolation=interpolation,
**kwargs,
)

if aspect is None and not (
im.is_transform_set() and not im.get_transform().contains_branch(ax.transData)
):
aspect = matplotlib.rcParams["image.aspect"]
if aspect is not None:
ax.set_aspect(aspect)

im.set_data(x, y, A)
im.set_alpha(alpha)
if im.get_clip_path() is None:
im.set_clip_path(ax.patch)
im._scale_norm(norm, vmin, vmax)
im.set_url(url)

im.set_extent(im.get_extent())

ax.add_image(im)
return im


def plot_array(
arr: xr.DataArray,
ax: matplotlib.axes.Axes | None = None,
Expand All @@ -186,10 +237,14 @@ def plot_array(
func_args: dict | None = None,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
rasterized: bool = True,
**improps,
) -> matplotlib.image.AxesImage:
"""Plot a 2D :class:`xarray.DataArray` using :func:`matplotlib.pyplot.imshow`.
If the input array is detected to have non-evenly spaced coordinates, it is plotted
as a :class:`matplotlib.image.NonUniformImage`.
Parameters
----------
arr
Expand All @@ -214,16 +269,30 @@ def plot_array(
func_args
Keyword arguments passed onto `func`.
rtol, atol
By default, the input array is checked for evenly spaced coordinates. ``rtol``
and ``atol`` are the tolerances for the coordinates to be considered evenly
spaced. The default values are consistent with `numpy.isclose`.
By default, the input array is checked for evenly spaced coordinates. If it is
not evenly spaced, it is plotted as a :class:`matplotlib.image.NonUniformImage`
instead of a :class:`matplotlib.image.AxesImage`. ``rtol`` and ``atol`` are the
tolerances for the coordinates to be considered evenly spaced. The default
values are consistent with `numpy.isclose`.
rasterized
Force rasterized output.
**improps
Keyword arguments passed onto :func:`matplotlib.axes.Axes.imshow`.
Keyword arguments passed onto :func:`matplotlib.pyplot.imshow`.
Returns
-------
matplotlib.image.AxesImage
Notes
-----
Some keyword arguments have different default behavior compared to matplotlib.
- `interpolation` is set to ``'none'`` for evenly spaced data and ``'nearest'`` for
nonuniform data.
- `aspect` is set to ``'auto'``.
- `origin` is set to ``'lower'``.
- The image is rasterized by default.
"""
if colorbar_kw is None:
colorbar_kw = {}
Expand All @@ -232,6 +301,8 @@ def plot_array(

if isinstance(arr, np.ndarray):
arr = xr.DataArray(arr)
if arr.ndim != 2:
raise ValueError("Input array must be 2D")

if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -266,13 +337,7 @@ def plot_array(
if norm is None:
norm = copy.deepcopy(matplotlib.colors.PowerNorm(gamma, **norm_kw))

improps_default = {
"interpolation": "none",
"extent": array_extent(arr, rtol, atol),
"aspect": "auto",
"origin": "lower",
"rasterized": True,
}
improps_default = {"aspect": "auto", "origin": "lower", "rasterized": rasterized}
for k, v in improps_default.items():
improps.setdefault(k, v)

Expand All @@ -283,9 +348,23 @@ def plot_array(
arr = arr.copy(deep=True).sel({arr.dims[0]: slice(*ylim)})

if func is not None:
img = ax.imshow(func(arr, **func_args), norm=norm, **improps)
arr = func(arr.copy(deep=True), **func_args)

if is_dims_uniform(arr, rtol=rtol, atol=atol):
improps.setdefault("interpolation", "none")
img = ax.imshow(
arr.values, norm=norm, extent=array_extent(arr, rtol, atol), **improps
)
else:
img = ax.imshow(arr.values, norm=norm, **improps)
improps.setdefault("interpolation", "nearest")
img = _imshow_nonuniform(
ax,
x=arr[arr.dims[1]].values,
y=arr[arr.dims[0]].values,
A=arr.values,
norm=norm,
**improps,
)

ax.set_xlabel(str(arr.dims[1]))
ax.set_ylabel(str(arr.dims[0]))
Expand Down
23 changes: 16 additions & 7 deletions tests/plotting/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,22 @@ def test_plot_slices():
@pytest.mark.parametrize("ylim", [None, 0.5, (-0.3, 0.7)])
@pytest.mark.parametrize("xlim", [None, 0.5, (-0.3, 0.7)])
@pytest.mark.parametrize("colorbar", [False, True])
def test_plot_array(colorbar, xlim, ylim, crop, kwargs):
data = xr.DataArray(
np.random.default_rng(0).random((11, 11)),
coords=[np.linspace(-1, 1, 11), np.linspace(-1, 1, 11)],
dims=["x", "y"],
)

@pytest.mark.parametrize(
"data",
[
xr.DataArray(
np.random.default_rng(0).random((11, 11)),
coords=[np.linspace(-1, 1, 11), np.linspace(-1, 1, 11)],
dims=["x", "y"],
),
xr.DataArray(
np.random.default_rng(0).random((11, 11)),
coords=[np.linspace(-1, 1, 11) ** 3, np.linspace(-1, 1, 11)],
dims=["x", "y"],
),
],
)
def test_plot_array(data, colorbar, xlim, ylim, crop, kwargs):
_, ax = plt.subplots()
plot_array(data, colorbar=colorbar, xlim=xlim, ylim=ylim, crop=crop, **kwargs)
assert ax.get_xlabel() == "y"
Expand Down

0 comments on commit 86d8c1a

Please sign in to comment.