diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d23e55d60..4ea450404a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,9 +16,9 @@ * Added the functionality [interactive legends](https://docs.bokeh.org/en/1.4.0/docs/user_guide/interaction/legends.html) for bokeh plots of `densityplot`, `energyplot` and `essplot` (#1024) * New defaults for cross validation: `loo` (old: waic) and `log` -scale (old: `deviance` -scale) (#1067) -* **Experimental Feature**: Added `arviz.wrappers` module to allow ArviZ to - refit the models if necessary (#771) +* **Experimental Feature**: Added `arviz.wrappers` module to allow ArviZ to refit the models if necessary (#771) * **Experimental Feature**: Added `reloo` function to ArviZ (#771) +* Added new helper function `matplotlib_kwarg_dealiaser` (#1073) * ArviZ version to InferenceData attributes. (#1086) * Add `log_likelihood` argument to `from_pymc3` (#1082) * Integrated rcParams for `plot.bokeh.layout` and `plot.backend`. (#1089) diff --git a/arviz/plots/backends/matplotlib/distplot.py b/arviz/plots/backends/matplotlib/distplot.py index 4ba197770e..1c56e27883 100644 --- a/arviz/plots/backends/matplotlib/distplot.py +++ b/arviz/plots/backends/matplotlib/distplot.py @@ -4,6 +4,7 @@ from . import backend_show from ...kdeplot import plot_kde +from ...plot_utils import matplotlib_kwarg_dealiaser def plot_dist( @@ -49,9 +50,7 @@ def plot_dist( ) elif kind == "kde": - if plot_kwargs is None: - plot_kwargs = {} - + plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot") plot_kwargs.setdefault("color", color) legend = label is not None diff --git a/arviz/plots/backends/matplotlib/essplot.py b/arviz/plots/backends/matplotlib/essplot.py index 0e79fff237..36eef38d3d 100644 --- a/arviz/plots/backends/matplotlib/essplot.py +++ b/arviz/plots/backends/matplotlib/essplot.py @@ -7,6 +7,7 @@ from ...plot_utils import ( make_label, _create_axes_grid, + matplotlib_kwarg_dealiaser, ) @@ -63,8 +64,7 @@ def plot_ess( ess_tail = ess_tail_dataset[var_name].sel(**selection) ax_.plot(xdata, ess_tail, **extra_kwargs) elif rug: - if rug_kwargs is None: - rug_kwargs = {} + rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot") if not hasattr(idata, "sample_stats"): raise ValueError("InferenceData object must contain sample_stats for rug plot") if not hasattr(idata.sample_stats, rug_kind): diff --git a/arviz/plots/backends/matplotlib/kdeplot.py b/arviz/plots/backends/matplotlib/kdeplot.py index 63c01030da..9acd045c00 100644 --- a/arviz/plots/backends/matplotlib/kdeplot.py +++ b/arviz/plots/backends/matplotlib/kdeplot.py @@ -4,7 +4,7 @@ import numpy as np from . import backend_show -from ...plot_utils import _scale_fig_size +from ...plot_utils import _scale_fig_size, matplotlib_kwarg_dealiaser def plot_kde( @@ -53,19 +53,15 @@ def plot_kde( figsize, *_, xt_labelsize, linewidth, markersize = _scale_fig_size(figsize, textsize, 1, 1) if values2 is None: - if plot_kwargs is None: - plot_kwargs = {} + plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot") plot_kwargs.setdefault("color", "C0") default_color = plot_kwargs.get("color") - if fill_kwargs is None: - fill_kwargs = {} - + fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin") fill_kwargs.setdefault("color", default_color) - if rug_kwargs is None: - rug_kwargs = {} + rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot") rug_kwargs.setdefault("marker", "_" if rotated else "|") rug_kwargs.setdefault("linestyle", "None") rug_kwargs.setdefault("color", default_color) @@ -122,13 +118,10 @@ def plot_kde( if legend and label: ax.legend() else: - if contour_kwargs is None: - contour_kwargs = {} + contour_kwargs = matplotlib_kwarg_dealiaser(contour_kwargs, "contour") contour_kwargs.setdefault("colors", "0.5") - if contourf_kwargs is None: - contourf_kwargs = {} - if pcolormesh_kwargs is None: - pcolormesh_kwargs = {} + contourf_kwargs = matplotlib_kwarg_dealiaser(contourf_kwargs, "contour") + pcolormesh_kwargs = matplotlib_kwarg_dealiaser(pcolormesh_kwargs, "pcolormesh") # gridsize = (128, 128) if contour else (256, 256) diff --git a/arviz/plots/backends/matplotlib/mcseplot.py b/arviz/plots/backends/matplotlib/mcseplot.py index 5325964d98..aa938fc132 100644 --- a/arviz/plots/backends/matplotlib/mcseplot.py +++ b/arviz/plots/backends/matplotlib/mcseplot.py @@ -8,6 +8,7 @@ from ...plot_utils import ( make_label, _create_axes_grid, + matplotlib_kwarg_dealiaser, ) @@ -87,8 +88,7 @@ def plot_mcse( **text_kwargs, ) if rug: - if rug_kwargs is None: - rug_kwargs = {} + rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot") if not hasattr(idata, "sample_stats"): raise ValueError("InferenceData object must contain sample_stats for rug plot") if not hasattr(idata.sample_stats, rug_kind): diff --git a/arviz/plots/distplot.py b/arviz/plots/distplot.py index f4f6912784..cb2234db6e 100644 --- a/arviz/plots/distplot.py +++ b/arviz/plots/distplot.py @@ -1,6 +1,7 @@ # pylint: disable=unexpected-keyword-arg """Plot distribution as histogram or kernel density estimates.""" -from .plot_utils import get_bins, get_plotting_function + +from .plot_utils import get_bins, get_plotting_function, matplotlib_kwarg_dealiaser from ..rcparams import rcParams @@ -148,9 +149,7 @@ def plot_dist( kind = "hist" if values.dtype.kind == "i" else "kde" if kind == "hist": - if hist_kwargs is None: - hist_kwargs = {} - + hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist") hist_kwargs.setdefault("bins", get_bins(values)) hist_kwargs.setdefault("cumulative", cumulative) hist_kwargs.setdefault("color", color) diff --git a/arviz/plots/elpdplot.py b/arviz/plots/elpdplot.py index b3fb9644f7..cc4d1f9f31 100644 --- a/arviz/plots/elpdplot.py +++ b/arviz/plots/elpdplot.py @@ -5,7 +5,13 @@ from matplotlib.lines import Line2D from ..data import convert_to_inference_data -from .plot_utils import get_coords, format_coords_as_labels, color_from_dim, get_plotting_function +from .plot_utils import ( + get_coords, + format_coords_as_labels, + color_from_dim, + get_plotting_function, + matplotlib_kwarg_dealiaser, +) from ..stats import waic, loo, ELPDData from ..rcparams import rcParams @@ -146,8 +152,8 @@ def plot_elpd( if coords is None: coords = {} - if plot_kwargs is None: - plot_kwargs = {} + plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "scatter") + if backend == "bokeh": plot_kwargs.setdefault("marker", rcParams["plot.bokeh.marker"]) diff --git a/arviz/plots/energyplot.py b/arviz/plots/energyplot.py index 8a9e915115..24381fc549 100644 --- a/arviz/plots/energyplot.py +++ b/arviz/plots/energyplot.py @@ -4,7 +4,7 @@ import numpy as np from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size, get_plotting_function +from .plot_utils import _scale_fig_size, get_plotting_function, matplotlib_kwarg_dealiaser from ..rcparams import rcParams @@ -93,11 +93,9 @@ def plot_energy( """ energy = convert_to_dataset(data, group="sample_stats").energy.values - if fill_kwargs is None: - fill_kwargs = {} - - if plot_kwargs is None: - plot_kwargs = {} + fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin") + types = "hist" if kind in {"hist", "histogram"} else "plot" + plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, types) figsize, _, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize, 1, 1) diff --git a/arviz/plots/essplot.py b/arviz/plots/essplot.py index cb07b8035d..cb6a2445a1 100644 --- a/arviz/plots/essplot.py +++ b/arviz/plots/essplot.py @@ -11,6 +11,7 @@ get_coords, filter_plotters_list, get_plotting_function, + matplotlib_kwarg_dealiaser, ) from ..rcparams import rcParams from ..utils import _var_names @@ -248,14 +249,15 @@ def plot_ess( (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size( figsize, textsize, rows, cols ) - _linestyle = kwargs.pop("ls", "-" if kind == "evolution" else "none") + kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot") + _linestyle = "-" if kind == "evolution" else "none" kwargs.setdefault("linestyle", _linestyle) - kwargs.setdefault("linewidth", kwargs.pop("lw", _linewidth)) - kwargs.setdefault("markersize", kwargs.pop("ms", _markersize)) + kwargs.setdefault("linewidth", _linewidth) + kwargs.setdefault("markersize", _markersize) kwargs.setdefault("marker", "o") kwargs.setdefault("zorder", 3) - if extra_kwargs is None: - extra_kwargs = {} + + extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot") if kind == "evolution": extra_kwargs = { **extra_kwargs, @@ -264,28 +266,27 @@ def plot_ess( kwargs.setdefault("label", "bulk") extra_kwargs.setdefault("label", "tail") else: - extra_kwargs.setdefault("linestyle", extra_kwargs.pop("ls", "-")) - extra_kwargs.setdefault("linewidth", extra_kwargs.pop("lw", _linewidth / 2)) + extra_kwargs.setdefault("linestyle", "-") + extra_kwargs.setdefault("linewidth", _linewidth / 2) extra_kwargs.setdefault("color", "k") extra_kwargs.setdefault("alpha", 0.5) kwargs.setdefault("label", kind) - if hline_kwargs is None: - hline_kwargs = {} - hline_kwargs.setdefault("linewidth", hline_kwargs.pop("lw", _linewidth)) - hline_kwargs.setdefault("linestyle", hline_kwargs.pop("ls", "--")) - hline_kwargs.setdefault("color", hline_kwargs.pop("c", "gray")) + + hline_kwargs = matplotlib_kwarg_dealiaser(hline_kwargs, "plot") + hline_kwargs.setdefault("linewidth", _linewidth) + hline_kwargs.setdefault("linestyle", "--") + hline_kwargs.setdefault("color", "gray") hline_kwargs.setdefault("alpha", 0.7) if extra_methods: mean_ess = ess(data, var_names=var_names, method="mean", relative=relative) sd_ess = ess(data, var_names=var_names, method="sd", relative=relative) - if text_kwargs is None: - text_kwargs = {} + text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text") text_x = text_kwargs.pop("x", 1) - text_kwargs.setdefault("fontsize", text_kwargs.pop("size", xt_labelsize * 0.7)) + text_kwargs.setdefault("fontsize", xt_labelsize * 0.7) text_kwargs.setdefault("alpha", extra_kwargs["alpha"]) text_kwargs.setdefault("color", extra_kwargs["color"]) - text_kwargs.setdefault("horizontalalignment", text_kwargs.pop("ha", "right")) - text_va = text_kwargs.pop("verticalalignment", text_kwargs.pop("va", None)) + text_kwargs.setdefault("horizontalalignment", "right") + text_va = text_kwargs.pop("verticalalignment", None) essplot_kwargs = dict( ax=ax, diff --git a/arviz/plots/hpdplot.py b/arviz/plots/hpdplot.py index 563dc62e44..93f56a9453 100644 --- a/arviz/plots/hpdplot.py +++ b/arviz/plots/hpdplot.py @@ -4,7 +4,7 @@ from scipy.signal import savgol_filter from ..stats import hpd -from .plot_utils import get_plotting_function +from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser from ..rcparams import rcParams @@ -64,13 +64,11 @@ def plot_hpd( ------- axes : matplotlib axes or bokeh figures """ - if plot_kwargs is None: - plot_kwargs = {} + plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot") plot_kwargs.setdefault("color", color) plot_kwargs.setdefault("alpha", 0) - if fill_kwargs is None: - fill_kwargs = {} + fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin") fill_kwargs.setdefault("color", color) fill_kwargs.setdefault("alpha", 0.5) diff --git a/arviz/plots/jointplot.py b/arviz/plots/jointplot.py index d580a8f97c..eb1f4466a7 100644 --- a/arviz/plots/jointplot.py +++ b/arviz/plots/jointplot.py @@ -1,6 +1,12 @@ """Joint scatter plot of two variables.""" from ..data import convert_to_dataset -from .plot_utils import _scale_fig_size, xarray_var_iter, get_coords, get_plotting_function +from .plot_utils import ( + _scale_fig_size, + xarray_var_iter, + get_coords, + get_plotting_function, + matplotlib_kwarg_dealiaser, +) from ..rcparams import rcParams from ..utils import _var_names @@ -157,8 +163,13 @@ def plot_joint( figsize, ax_labelsize, _, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, textsize) - if joint_kwargs is None: - joint_kwargs = {} + if kind == "kde": + types = "plot" + elif kind == "scatter": + types = "scatter" + else: + types = "hexbin" + joint_kwargs = matplotlib_kwarg_dealiaser(joint_kwargs, types) if marginal_kwargs is None: marginal_kwargs = {} diff --git a/arviz/plots/khatplot.py b/arviz/plots/khatplot.py index 935367831c..c62cfae63d 100644 --- a/arviz/plots/khatplot.py +++ b/arviz/plots/khatplot.py @@ -11,6 +11,7 @@ color_from_dim, format_coords_as_labels, get_plotting_function, + matplotlib_kwarg_dealiaser, ) from ..stats import ELPDData from ..rcparams import rcParams @@ -132,8 +133,7 @@ def plot_khat( >>> az.plot_khat(loo_radon, color=colors) """ - if hlines_kwargs is None: - hlines_kwargs = {} + hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines") hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"]) hlines_kwargs.setdefault("alpha", 0.7) hlines_kwargs.setdefault("zorder", -1) @@ -173,6 +173,8 @@ def plot_khat( if markersize is None: markersize = scaled_markersize ** 2 # s in scatter plot mus be markersize square # for dots to have the same size + + kwargs = matplotlib_kwarg_dealiaser(kwargs, "scatter") kwargs.setdefault("s", markersize) kwargs.setdefault("marker", "+") color_mapping = None diff --git a/arviz/plots/loopitplot.py b/arviz/plots/loopitplot.py index f9e3ac8ae8..40cf225fca 100644 --- a/arviz/plots/loopitplot.py +++ b/arviz/plots/loopitplot.py @@ -9,6 +9,7 @@ _scale_fig_size, get_plotting_function, _fast_kde, + matplotlib_kwarg_dealiaser, ) from ..rcparams import rcParams @@ -137,8 +138,7 @@ def plot_loo_pit( loo_pit = _loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights) loo_pit = loo_pit.flatten() if isinstance(loo_pit, np.ndarray) else loo_pit.values.flatten() - if plot_kwargs is None: - plot_kwargs = {} + plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot") plot_kwargs["color"] = to_hex(color) plot_kwargs.setdefault("linewidth", linewidth * 1.4) if isinstance(y, str): @@ -155,8 +155,7 @@ def plot_loo_pit( plot_kwargs.setdefault("label", label) plot_kwargs.setdefault("zorder", 5) - if plot_unif_kwargs is None: - plot_unif_kwargs = {} + plot_unif_kwargs = matplotlib_kwarg_dealiaser(plot_unif_kwargs, "plot") light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color"))) light_color[1] /= 2 # pylint: disable=unsupported-assignment-operation light_color[2] += (1 - light_color[2]) / 2 # pylint: disable=unsupported-assignment-operation diff --git a/arviz/plots/mcseplot.py b/arviz/plots/mcseplot.py index 562f753c94..fac8f15359 100644 --- a/arviz/plots/mcseplot.py +++ b/arviz/plots/mcseplot.py @@ -11,6 +11,7 @@ get_coords, filter_plotters_list, get_plotting_function, + matplotlib_kwarg_dealiaser, ) from ..rcparams import rcParams from ..utils import _var_names @@ -134,28 +135,29 @@ def plot_mcse( (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size( figsize, textsize, rows, cols ) - kwargs.setdefault("linestyle", kwargs.pop("ls", "none")) - kwargs.setdefault("linewidth", kwargs.pop("lw", _linewidth)) - kwargs.setdefault("markersize", kwargs.pop("ms", _markersize)) + kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot") + kwargs.setdefault("linestyle", "none") + kwargs.setdefault("linewidth", _linewidth) + kwargs.setdefault("markersize", _markersize) kwargs.setdefault("marker", "_" if errorbar else "o") kwargs.setdefault("zorder", 3) - if extra_kwargs is None: - extra_kwargs = {} - extra_kwargs.setdefault("linestyle", extra_kwargs.pop("ls", "-")) - extra_kwargs.setdefault("linewidth", extra_kwargs.pop("lw", _linewidth / 2)) + + extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot") + extra_kwargs.setdefault("linestyle", "-") + extra_kwargs.setdefault("linewidth", _linewidth / 2) extra_kwargs.setdefault("color", "k") extra_kwargs.setdefault("alpha", 0.5) if extra_methods: mean_mcse = mcse(data, var_names=var_names, method="mean") sd_mcse = mcse(data, var_names=var_names, method="sd") - if text_kwargs is None: - text_kwargs = {} + + text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text") text_x = text_kwargs.pop("x", 1) - text_kwargs.setdefault("fontsize", text_kwargs.pop("size", xt_labelsize * 0.7)) + text_kwargs.setdefault("fontsize", xt_labelsize * 0.7) text_kwargs.setdefault("alpha", extra_kwargs["alpha"]) text_kwargs.setdefault("color", extra_kwargs["color"]) - text_kwargs.setdefault("horizontalalignment", text_kwargs.pop("ha", "right")) - text_va = text_kwargs.pop("verticalalignment", text_kwargs.pop("va", None)) + text_kwargs.setdefault("horizontalalignment", "right") + text_va = text_kwargs.pop("verticalalignment", None) mcse_kwargs = dict( ax=ax, diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 43c6377208..68f0e26b83 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -12,6 +12,7 @@ import numpy as np import matplotlib.pyplot as plt import matplotlib as mpl +import matplotlib.cbook as cbook import xarray as xr @@ -914,3 +915,24 @@ def _fast_kde_2d(x, y, gridsize=(128, 128), circular=False): grid /= norm_factor return grid, xmin, xmax, ymin, ymax + + +def matplotlib_kwarg_dealiaser(args, kind, backend="matplotlib"): + """De-aliase the kwargs passed to plots.""" + if args is None: + return {} + matplotlib_kwarg_dealiaser_dict = { + "scatter": mpl.collections.PathCollection, + "plot": mpl.lines.Line2D, + "hist": mpl.patches.Patch, + "hexbin": mpl.collections.PolyCollection, + "hlines": mpl.collections.LineCollection, + "text": mpl.text.Text, + "contour": mpl.contour.ContourSet, + "pcolormesh": mpl.collections.QuadMesh, + } + if backend == "matplotlib": + return cbook.normalize_kwargs( + args, getattr(matplotlib_kwarg_dealiaser_dict[kind], "_alias_map", {}) + ) + return args diff --git a/arviz/plots/posteriorplot.py b/arviz/plots/posteriorplot.py index 561b0b89d9..6bfe7f0abe 100644 --- a/arviz/plots/posteriorplot.py +++ b/arviz/plots/posteriorplot.py @@ -9,6 +9,7 @@ get_coords, filter_plotters_list, get_plotting_function, + matplotlib_kwarg_dealiaser, ) from ..utils import _var_names from ..rcparams import rcParams @@ -208,6 +209,10 @@ def plot_posterior( (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _) = _scale_fig_size( figsize, textsize, rows, cols ) + if kind == "hist": + kwargs = matplotlib_kwarg_dealiaser(kwargs, "hist") + else: + kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot") kwargs.setdefault("linewidth", _linewidth) posteriorplot_kwargs = dict( diff --git a/arviz/plots/traceplot.py b/arviz/plots/traceplot.py index 55c94f7d4d..e0628ece87 100644 --- a/arviz/plots/traceplot.py +++ b/arviz/plots/traceplot.py @@ -5,7 +5,13 @@ import matplotlib.pyplot as plt -from .plot_utils import get_plotting_function, get_coords, xarray_var_iter, KwargSpec +from .plot_utils import ( + get_plotting_function, + get_coords, + xarray_var_iter, + KwargSpec, + matplotlib_kwarg_dealiaser, +) from ..data import convert_to_dataset, InferenceData, CoordSpec from ..utils import _var_names from ..rcparams import rcParams @@ -216,8 +222,7 @@ def plot_trace( if figsize is None: figsize = (12, len(plotters) * 2) - if trace_kwargs is None: - trace_kwargs = {} + trace_kwargs = matplotlib_kwarg_dealiaser(trace_kwargs, "plot") trace_kwargs.setdefault("alpha", 0.35) if hist_kwargs is None: diff --git a/arviz/plots/violinplot.py b/arviz/plots/violinplot.py index e547124343..61e33b10c2 100644 --- a/arviz/plots/violinplot.py +++ b/arviz/plots/violinplot.py @@ -6,6 +6,7 @@ filter_plotters_list, default_grid, get_plotting_function, + matplotlib_kwarg_dealiaser, ) from ..utils import _var_names from ..rcparams import rcParams @@ -116,8 +117,7 @@ def plot_violin( list(xarray_var_iter(data, var_names=var_names, combined=True)), "plot_violin" ) - if shade_kwargs is None: - shade_kwargs = {} + shade_kwargs = matplotlib_kwarg_dealiaser(shade_kwargs, "hexbin") rows, cols = default_grid(len(plotters)) @@ -125,8 +125,7 @@ def plot_violin( figsize, textsize, rows, cols ) - if rug_kwargs is None: - rug_kwargs = {} + rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot") if credible_interval is None: credible_interval = rcParams["stats.credible_interval"] diff --git a/arviz/tests/base_tests/test_plot_utils.py b/arviz/tests/base_tests/test_plot_utils.py index b6b795cf1b..4ce2e643ea 100644 --- a/arviz/tests/base_tests/test_plot_utils.py +++ b/arviz/tests/base_tests/test_plot_utils.py @@ -13,6 +13,7 @@ filter_plotters_list, format_sig_figs, get_plotting_function, + matplotlib_kwarg_dealiaser, ) from ...rcparams import rc_context @@ -200,3 +201,24 @@ def test_bokeh_import(): from arviz.plots.backends.bokeh.distplot import plot_dist assert plot is plot_dist + + +@pytest.mark.parametrize( + "params", + [ + {"input": ({"dashes": "-",}, "scatter"), "output": "linestyle", }, + { + "input": ({"mfc": "blue", "c": "blue", "line_width": 2}, "plot",), + "output": ("markerfacecolor", "color", "line_width"), + }, + {"input": ({"ec": "blue", "fc": "black"}, "hist"), "output": ("edgecolor", "facecolor")}, + { + "input": ({"edgecolors": "blue", "lw": 3}, "hlines"), + "output": ("edgecolor", "linewidth"), + }, + ], +) +def test_matplotlib_kwarg_dealiaser(params): + dealiased = matplotlib_kwarg_dealiaser(params["input"][0], kind=params["input"][1]) + for returned in dealiased: + assert returned in params["output"]