Skip to content
forked from pydata/xarray

Commit

Permalink
Avoid map_groupby & map_groupby_line. Using map_dataarray & map_dataa…
Browse files Browse the repository at this point in the history
…rray_line instead.
  • Loading branch information
dcherian committed Jan 14, 2020
1 parent 35d2b50 commit 0ab1af8
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 105 deletions.
139 changes: 38 additions & 101 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from ..core.formatting import format_item
from typing import Mapping
from .utils import (
_infer_xy_labels,
_process_cmap_cbar_kwargs,
Expand Down Expand Up @@ -32,7 +33,7 @@ def _nicetitle(coord, value, maxchar, template):
return title


def parse_sharex_sharey(data, sharex, sharey):
def _parse_sharex_sharey(data, sharex, sharey):

from ..core.groupby import GroupBy

Expand All @@ -50,6 +51,25 @@ def parse_sharex_sharey(data, sharex, sharey):
return sharex, sharey


def _get_subset(data, key: Mapping, expected_ndim=None):
""" Index with "key" using either .loc or get_group as appropriate. """
from ..core.groupby import GroupBy

if isinstance(data, GroupBy):
label = list(key.values())[0]
group = data.get_group(label)

# This path is needed when some groups have fewer dimensions than other groups.
if expected_ndim:
if expected_ndim > group.ndim:
group = group.expand_dims(data._group_dim)
elif expected_ndim < group.ndim:
group = group.squeeze()
return group
else:
return data.loc[key]


class FacetGrid:
"""
Initialize the matplotlib figure and FacetGrid object.
Expand Down Expand Up @@ -179,7 +199,7 @@ def __init__(
cbar_space = 1
figsize = (ncol * size * aspect + cbar_space, nrow * size)

sharex, sharey = parse_sharex_sharey(data, sharex, sharey)
sharex, sharey = _parse_sharex_sharey(data, sharex, sharey)
fig, axes = plt.subplots(
nrow,
ncol,
Expand Down Expand Up @@ -260,13 +280,16 @@ def map_dataarray(self, func, x, y, **kwargs):
self : FacetGrid object
"""
from ..core.groupby import GroupBy

if kwargs.get("cbar_ax", None) is not None:
raise ValueError("cbar_ax not supported by FacetGrid.")

cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
func, self.data.values, **kwargs
)
if isinstance(self.data, GroupBy):
data = self.data._obj.values
else:
data = self.data.values
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(func, data, **kwargs)

self._cmap_extend = cmap_params.get("extend")

Expand All @@ -281,7 +304,7 @@ def map_dataarray(self, func, x, y, **kwargs):

# Get x, y labels for the first subplot
x, y = _infer_xy_labels(
darray=self.data.loc[self.name_dicts.flat[0]],
darray=_get_subset(self.data, self.name_dicts.flat[0], expected_ndim=2),
x=x,
y=y,
imshow=func.__name__ == "imshow",
Expand All @@ -291,7 +314,7 @@ def map_dataarray(self, func, x, y, **kwargs):
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
subset = _get_subset(self.data, d, expected_ndim=2)
mappable = func(subset, x=x, y=y, ax=ax, **func_kwargs)
self._mappables.append(mappable)

Expand All @@ -310,7 +333,7 @@ def map_dataarray_line(
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
subset = _get_subset(self.data, d, expected_ndim=1)
mappable = func(
subset,
x=x,
Expand All @@ -324,7 +347,10 @@ def map_dataarray_line(
self._mappables.append(mappable)

_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue
darray=_get_subset(self.data, self.name_dicts.flat[0], expected_ndim=1),
x=x,
y=y,
hue=hue,
)

self._hue_var = hueplt
Expand Down Expand Up @@ -382,84 +408,6 @@ def map_dataset(

return self

def map_groupby(
self, func, x=None, y=None, hue=None, hue_style=None, add_guide=None, **kwargs
):

if kwargs.get("cbar_ax", None) is not None:
raise ValueError("cbar_ax not supported by FacetGrid.")

cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
func, self.data._obj.values, **kwargs
)

self._cmap_extend = cmap_params.get("extend")

# Order is important
func_kwargs = {
k: v
for k, v in kwargs.items()
if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
}
func_kwargs.update(cmap_params)
func_kwargs.update({"add_colorbar": False, "add_labels": False})

# Get x, y labels for the first subplot
first_label = list(self.name_dicts.flat[0].values())[0]
x, y = _infer_xy_labels(
darray=self.data.get_group(first_label),
x=x,
y=y,
imshow=func.__name__ == "imshow",
rgb=kwargs.get("rgb", None),
)

for (_, subset), ax in zip(self.data, self.axes.flat):
mappable = func(subset.squeeze(), x=x, y=y, ax=ax, **func_kwargs)
self._mappables.append(mappable)

self._finalize_grid(x, y)

if kwargs.get("add_colorbar", True):
self.add_colorbar(**cbar_kwargs)

return self

def map_groupby_line(
self, func, x, y, hue, add_legend=True, _labels=None, **kwargs
):
from .plot import _infer_line_data

grouped = self.data
first_group = grouped._obj.isel(
**{grouped._group_dim: grouped._group_indices[0]}
).squeeze()
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=first_group, x=x, y=y, hue=hue
)

for (_, subset), ax in zip(self.data, self.axes.flat):
mappable = func(
subset.squeeze(),
x=x,
y=y,
ax=ax,
hue=hue,
add_legend=False,
_labels=False,
**kwargs,
)
self._mappables.append(mappable)

self._hue_var = hueplt
self._hue_label = huelabel
self._finalize_grid(xlabel, ylabel)

if add_legend and hueplt is not None and huelabel is not None:
self.add_legend()

return self

def _finalize_grid(self, *axlabels):
"""Finalize the annotations and layout."""
if not self._finalized:
Expand Down Expand Up @@ -669,16 +617,11 @@ def map(self, func, *args, **kwargs):
self : FacetGrid object
"""
from ..core.groupby import GroupBy

plt = import_matplotlib_pyplot()

for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
if namedict is not None:
if not isinstance(self.data, GroupBy):
data = self.data.loc[namedict]
else:
data = self.data.get_group(namedict).squeeze()
data = _get_subset(self.data, namedict, expected_ndim=None)
plt.sca(ax)
innerargs = [data[a].values for a in args]
maybe_mappable = func(*innerargs, **kwargs)
Expand Down Expand Up @@ -737,17 +680,11 @@ def _easy_facetgrid(
subplot_kws=subplot_kws,
)

if kind == "line":
if kind == "line" or kind == "groupby_line":
return g.map_dataarray_line(plotfunc, x, y, **kwargs)

if kind == "dataarray":
if kind == "dataarray" or kind == "groupby":
return g.map_dataarray(plotfunc, x, y, **kwargs)

if kind == "dataset":
return g.map_dataset(plotfunc, x, y, **kwargs)

if kind == "groupby":
return g.map_groupby(plotfunc, x, y, **kwargs)

if kind == "groupby_line":
return g.map_groupby_line(plotfunc, x, y, **kwargs)
2 changes: 1 addition & 1 deletion xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def plot(
if isinstance(darray, GroupBy):
grouped_over_unique_coord = darray._unique_coord.equals(
darray._obj.coords[darray._group_dim]
)
) or len(darray._unique_coord) == len(darray[darray._group_dim])
if grouped_over_unique_coord:
plot_dims.discard(darray._group_dim)

Expand Down
12 changes: 9 additions & 3 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2209,14 +2209,14 @@ def setup(self):
{
"variable": (
("lat", "lon", "time"),
np.arange(60.0).reshape((4, 3, 5)),
np.arange(72.0).reshape((4, 3, 6)),
),
"id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
},
coords={
"lat": np.arange(4),
"lon": np.arange(3),
"time": pd.date_range(start="2001-01-01", freq="12H", periods=5),
"time": pd.date_range(start="2001-01-01", freq="12H", periods=6),
},
)

Expand All @@ -2225,10 +2225,16 @@ def setup(self):
def test_cftime_grouping(self):
ds = self.ds.copy()
ds["time"] = xr.cftime_range(
start="0001-01-01", periods=5, freq="12H", calendar="noleap"
start="0001-01-01", periods=6, freq="12H", calendar="noleap"
)
ds.variable.sel(lat=0).groupby("time.day").plot(col="day", x="lon", sharey=True)

# TODO: can't plot single vector with plot2d when axis is CFTime
with raises_regex(TypeError, "unsupported operand"):
ds.variable.sel(lat=0).isel(time=slice(-1)).groupby("time.day").plot(
col="day", x="lon", sharey=True
)

def test_time_grouping(self):
self.ds.variable.sel(lat=0).groupby("time.day").plot(
col="day", x="lon", sharey=True
Expand Down

0 comments on commit 0ab1af8

Please sign in to comment.