From 944532f558f7170c3bd53f9e96ac563e13a8badb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 12 Jul 2022 22:01:42 +0200 Subject: [PATCH 1/4] Move dataset plot functions to utils --- xarray/plot/dataset_plot.py | 86 +------------------------------------ xarray/plot/utils.py | 83 +++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 85 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index aeb53126265..6ca7deae2ce 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -10,99 +10,15 @@ from .utils import ( _add_colorbar, _get_nice_quiver_magnitude, - _is_numeric, + _infer_meta_data, _process_cmap_cbar_kwargs, get_axis, - label_from_attrs, ) # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) -def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): - dvars = set(ds.variables.keys()) - error_msg = " must be one of ({:s})".format(", ".join(dvars)) - - if x not in dvars: - raise ValueError("x" + error_msg) - - if y not in dvars: - raise ValueError("y" + error_msg) - - if hue is not None and hue not in dvars: - raise ValueError("hue" + error_msg) - - if hue: - hue_is_numeric = _is_numeric(ds[hue].values) - - if hue_style is None: - hue_style = "continuous" if hue_is_numeric else "discrete" - - if not hue_is_numeric and (hue_style == "continuous"): - raise ValueError( - f"Cannot create a colorbar for a non numeric coordinate: {hue}" - ) - - if add_guide is None or add_guide is True: - add_colorbar = True if hue_style == "continuous" else False - add_legend = True if hue_style == "discrete" else False - else: - add_colorbar = False - add_legend = False - else: - if add_guide is True and funcname not in ("quiver", "streamplot"): - raise ValueError("Cannot set add_guide when hue is None.") - add_legend = False - add_colorbar = False - - if (add_guide or add_guide is None) and funcname == "quiver": - add_quiverkey = True - if hue: - add_colorbar = True - if not hue_style: - hue_style = "continuous" - elif hue_style != "continuous": - raise ValueError( - "hue_style must be 'continuous' or None for .plot.quiver or " - ".plot.streamplot" - ) - else: - add_quiverkey = False - - if (add_guide or add_guide is None) and funcname == "streamplot": - if hue: - add_colorbar = True - if not hue_style: - hue_style = "continuous" - elif hue_style != "continuous": - raise ValueError( - "hue_style must be 'continuous' or None for .plot.quiver or " - ".plot.streamplot" - ) - - if hue_style is not None and hue_style not in ["discrete", "continuous"]: - raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") - - if hue: - hue_label = label_from_attrs(ds[hue]) - hue = ds[hue] - else: - hue_label = None - hue = None - - return { - "add_colorbar": add_colorbar, - "add_legend": add_legend, - "add_quiverkey": add_quiverkey, - "hue_label": hue_label, - "hue_style": hue_style, - "xlabel": label_from_attrs(ds[x]), - "ylabel": label_from_attrs(ds[y]), - "hue": hue, - } - - def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None): broadcast_keys = ["x", "y"] diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index aef21f0be42..d0c0d27aa44 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1141,3 +1141,86 @@ def _adjust_legend_subtitles(legend): # The sutbtitles should have the same font size # as normal legend titles: text.set_size(font_size) + + +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): + dvars = set(ds.variables.keys()) + error_msg = " must be one of ({:s})".format(", ".join(dvars)) + + if x not in dvars: + raise ValueError("x" + error_msg) + + if y not in dvars: + raise ValueError("y" + error_msg) + + if hue is not None and hue not in dvars: + raise ValueError("hue" + error_msg) + + if hue: + hue_is_numeric = _is_numeric(ds[hue].values) + + if hue_style is None: + hue_style = "continuous" if hue_is_numeric else "discrete" + + if not hue_is_numeric and (hue_style == "continuous"): + raise ValueError( + f"Cannot create a colorbar for a non numeric coordinate: {hue}" + ) + + if add_guide is None or add_guide is True: + add_colorbar = True if hue_style == "continuous" else False + add_legend = True if hue_style == "discrete" else False + else: + add_colorbar = False + add_legend = False + else: + if add_guide is True and funcname not in ("quiver", "streamplot"): + raise ValueError("Cannot set add_guide when hue is None.") + add_legend = False + add_colorbar = False + + if (add_guide or add_guide is None) and funcname == "quiver": + add_quiverkey = True + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + else: + add_quiverkey = False + + if (add_guide or add_guide is None) and funcname == "streamplot": + if hue: + add_colorbar = True + if not hue_style: + hue_style = "continuous" + elif hue_style != "continuous": + raise ValueError( + "hue_style must be 'continuous' or None for .plot.quiver or " + ".plot.streamplot" + ) + + if hue_style is not None and hue_style not in ["discrete", "continuous"]: + raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") + + if hue: + hue_label = label_from_attrs(ds[hue]) + hue = ds[hue] + else: + hue_label = None + hue = None + + return { + "add_colorbar": add_colorbar, + "add_legend": add_legend, + "add_quiverkey": add_quiverkey, + "hue_label": hue_label, + "hue_style": hue_style, + "xlabel": label_from_attrs(ds[x]), + "ylabel": label_from_attrs(ds[y]), + "hue": hue, + } From 55f10f8167467418c8ddcb734a08974f67ec680d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 12 Jul 2022 22:07:16 +0200 Subject: [PATCH 2/4] move parse_size --- xarray/plot/dataset_plot.py | 42 +------------------------------------ xarray/plot/utils.py | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 6ca7deae2ce..def702fa74f 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -11,6 +11,7 @@ _add_colorbar, _get_nice_quiver_magnitude, _infer_meta_data, + _parse_size, _process_cmap_cbar_kwargs, get_axis, ) @@ -50,47 +51,6 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None) return data -# copied from seaborn -def _parse_size(data, norm): - - import matplotlib as mpl - - if data is None: - return None - - data = data.values.flatten() - - if not _is_numeric(data): - levels = np.unique(data) - numbers = np.arange(1, 1 + len(levels))[::-1] - else: - levels = numbers = np.sort(np.unique(data)) - - min_width, max_width = _MARKERSIZE_RANGE - # width_range = min_width, max_width - - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = "``size_norm`` must be None, tuple, or Normalize object." - raise ValueError(err) - - norm.clip = True - if not norm.scaled(): - norm(np.asarray(numbers)) - # limits = norm.vmin, norm.vmax - - scl = norm(numbers) - widths = np.asarray(min_width + scl * (max_width - min_width)) - if scl.mask.any(): - widths[scl.mask] = 0 - sizes = dict(zip(levels, widths)) - - return pd.Series(sizes) - - class _Dataset_PlotMethods: """ Enables use of xarray.plot functions as attributes on a Dataset. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index d0c0d27aa44..7d3211bb306 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -1224,3 +1224,44 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): "ylabel": label_from_attrs(ds[y]), "hue": hue, } + + +# copied from seaborn +def _parse_size(data, norm): + + import matplotlib as mpl + + if data is None: + return None + + data = data.values.flatten() + + if not _is_numeric(data): + levels = np.unique(data) + numbers = np.arange(1, 1 + len(levels))[::-1] + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, max_width = _MARKERSIZE_RANGE + # width_range = min_width, max_width + + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) From a83a2c3990786ba7199b6f92bf3b33c32f58877d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 12 Jul 2022 22:11:42 +0200 Subject: [PATCH 3/4] move markersize --- xarray/plot/dataset_plot.py | 3 --- xarray/plot/utils.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index def702fa74f..9863667ec21 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -16,9 +16,6 @@ get_axis, ) -# copied from seaborn -_MARKERSIZE_RANGE = np.array([18.0, 72.0]) - def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 7d3211bb306..cca644319c9 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -30,6 +30,8 @@ ROBUST_PERCENTILE = 2.0 +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) def import_matplotlib_pyplot(): """import pyplot""" From 4b2cb93d317c2e687de012ce2c4450fe7db5b8d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Jul 2022 20:13:15 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/plot/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index cca644319c9..02befbea422 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -33,6 +33,7 @@ # copied from seaborn _MARKERSIZE_RANGE = np.array([18.0, 72.0]) + def import_matplotlib_pyplot(): """import pyplot""" # TODO: This function doesn't do anything (after #6109), remove it?