From ab60b27aee78590816f6b190c2a7bb6761e4b8e4 Mon Sep 17 00:00:00 2001 From: Ari Hartikainen Date: Thu, 5 Mar 2020 12:28:36 +0200 Subject: [PATCH 1/4] [WIP] Modify pairplot to include jointplot and cornerplot like features (#1079) * add jointplot features into pairplot * add scatter_kde kind for pairplot * add point_estimate arguments * bokeh backend * fix None argument for color in kdeplot bokeh backend * run black, pylint and pytest * remove scatter_kde kind among several other changes * minor changes * run pytest * add plot width and height to backend_kwargs fix pylint issues fix hover feature fix hover feature minor fixes * update docstring run pylint * update changelog --- CHANGELOG.md | 2 ++ arviz/plots/backends/bokeh/pairplot.py | 1 + arviz/plots/backends/matplotlib/pairplot.py | 3 +++ arviz/plots/pairplot.py | 2 +- arviz/tests/base_tests/test_plots_matplotlib.py | 9 +++++++-- 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ebf6822400..1eb653f7dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Revamped the `hpd` function to make it work with mutidimensional arrays, InferenceData and xarray objects (#1117) * Skip test for optional/extra dependencies when not installed (#1113) * Add option to display rank plots instead of trace (#1134) +* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 ### Maintenance and fixes * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) @@ -217,3 +218,4 @@ ## v0.3.0 (2018 Dec 14) * First Beta Release + diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index e6d3e8289f..9e3b6741b3 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -150,6 +150,7 @@ def get_width_and_height(jointplot, rotate): ax = np.array(ax) else: assert ax.shape == (numvars - var, numvars - var) + # pylint: disable=too-many-nested-blocks for i in range(0, numvars - var): diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 12e1b7f2d4..d1c6aa6aa4 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -88,6 +88,9 @@ def plot_pair( for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)): plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs) + ax_hist_x.set_xlim(ax.get_xlim()) + ax_hist_y.set_ylim(ax.get_ylim()) + # Personalize axes ax_hist_x.tick_params(labelleft=False, labelbottom=False) ax_hist_y.tick_params(labelleft=False, labelbottom=False) diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 7adff78a47..02150d51d6 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -17,7 +17,7 @@ def plot_pair( textsize=None, kind="scatter", gridsize="auto", - contour=False, + contour=True, plot_kwargs=None, fill_last=False, divergences=False, diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 60facbe797..cafee80147 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -155,8 +155,13 @@ def test_plot_trace(models, kwargs): assert axes.shape -@pytest.mark.parametrize("compact", [True, False]) -@pytest.mark.parametrize("combined", [True, False]) +@pytest.mark.parametrize( + "compact", [True, False], +) +@pytest.mark.parametrize( + "combined", [True, False], +) + def test_plot_trace_legend(compact, combined): idata = load_arviz_data("rugby") axes = plot_trace( From 3e21a27a4125fe926a5f926bb148f7c3b944b010 Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo Date: Mon, 6 Apr 2020 17:04:12 -0300 Subject: [PATCH 2/4] add true_values argument for plot_pair remove whitespace --- CHANGELOG.md | 2 + arviz/plots/backends/bokeh/pairplot.py | 38 +++++++++++++++++ arviz/plots/backends/matplotlib/pairplot.py | 45 +++++++++++++++++++++ arviz/plots/pairplot.py | 9 +++++ 4 files changed, 94 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1eb653f7dd..16db85675b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## v0.x.x Unreleased ### New features +* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables +* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 * Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090 * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index 9e3b6741b3..d180419a61 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -34,6 +34,8 @@ def plot_pair( marginal_kwargs, point_estimate, point_estimate_kwargs, + true_values, + true_values_kwargs, show, ): """Bokeh pair plot.""" @@ -56,6 +58,36 @@ def plot_pair( kde_kwargs["contour_kwargs"].setdefault("line_color", "black") kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1) + if true_values: + true_values_copy = {} + label = [] + for variable in list(true_values.keys()): + if " " in variable: + variable_copy = variable.replace(" ", "\n", 1) + else: + variable_copy = variable + + label.append(variable_copy) + true_values_copy[variable_copy] = true_values[variable] + + difference = set(flat_var_names).difference(set(label)) + + for dif in difference: + true_values_copy[dif] = None + + if difference: + warn = [dif.replace("\n", " ", 1) for dif in difference] + warnings.warn( + "Argument true_values does not include true value for: {}".format(", ".join(warn)), + UserWarning, + ) + + if true_values_kwargs is None: + true_values_kwargs = {} + + true_values_kwargs.setdefault("line_color", "red") + true_values_kwargs.setdefault("line_width", 5) + dpi = backend_kwargs.pop("dpi") max_plots = ( numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"] @@ -269,6 +301,12 @@ def get_width_and_height(jointplot, rotate): ) ax[-1, -1].add_layout(ax_pe_hline) + if true_values: + x = true_values_copy[flat_var_names[j + var]] + y = true_values_copy[flat_var_names[i]] + if x and y: + ax[j, i].circle(y, x, **true_values_kwargs) + ax[j, i].xaxis.axis_label = flat_var_names[i] ax[j, i].yaxis.axis_label = flat_var_names[j + var] diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index d1c6aa6aa4..e2eae5839c 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -39,6 +39,8 @@ def plot_pair( point_estimate, point_estimate_kwargs, point_estimate_marker_kwargs, + true_values, + true_values_kwargs, ): """Matplotlib pairplot.""" if backend_kwargs is None: @@ -60,6 +62,36 @@ def plot_pair( kde_kwargs.setdefault("contour_kwargs", {}) kde_kwargs["contour_kwargs"].setdefault("colors", "k") + if true_values: + true_values_copy = {} + label = [] + for variable in list(true_values.keys()): + if " " in variable: + variable_copy = variable.replace(" ", "\n", 1) + else: + variable_copy = variable + + label.append(variable_copy) + true_values_copy[variable_copy] = true_values[variable] + + difference = set(flat_var_names).difference(set(label)) + + for dif in difference: + true_values_copy[dif] = None + + if difference: + warn = [dif.replace("\n", " ", 1) for dif in difference] + warnings.warn( + "Argument true_values does not include true value for: {}".format(", ".join(warn)), + UserWarning, + ) + + if true_values_kwargs is None: + true_values_kwargs = {} + + true_values_kwargs.setdefault("color", "C3") + true_values_kwargs.setdefault("s", 30) + # pylint: disable=too-many-nested-blocks if numvars == 2: (figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size( @@ -143,6 +175,12 @@ def plot_pair( ax.scatter(pe_x, pe_y, marker="s", s=figsize[0] + 50, **point_estimate_kwargs, zorder=4) + if true_values: + ax.scatter( + true_values_copy[flat_var_names[0]], + true_values_copy[flat_var_names[1]], + **true_values_kwargs, + ) ax.set_xlabel("{}".format(flat_var_names[0]), fontsize=ax_labelsize, wrap=True) ax.set_ylabel("{}".format(flat_var_names[1]), fontsize=ax_labelsize, wrap=True) ax.tick_params(labelsize=xt_labelsize) @@ -232,6 +270,13 @@ def plot_pair( pe_x, pe_y, s=figsize[0] + 50, zorder=4, **point_estimate_marker_kwargs ) + if true_values: + ax[j, i].scatter( + true_values_copy[flat_var_names[i]], + true_values_copy[flat_var_names[j]], + **true_values_kwargs, + ) + if j != numvars - 1: ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter()) else: diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 02150d51d6..730614e17f 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -34,6 +34,8 @@ def plot_pair( point_estimate=None, point_estimate_kwargs=None, point_estimate_marker_kwargs=None, + true_values=None, + true_values_kwargs=None, show=None, ): """ @@ -99,6 +101,11 @@ def plot_pair( Additional keywords passed to ax.vline, ax.hline (matplotlib) or ax.square, Span (bokeh) point_estimate_marker_kwargs: dict, optional Additional keywords passed to ax.scatter in point estimate plot. Not available in bokeh + true_values : dict, optional + True values for the plotted variables. The true values will be plotted + using a scatter marker + true_values_kwargs : dict, optional + Additional keywords passed to ax.scatter or ax.circle in true values plot show : bool, optional Call backend show function. @@ -276,6 +283,8 @@ def plot_pair( point_estimate=point_estimate, point_estimate_kwargs=point_estimate_kwargs, point_estimate_marker_kwargs=point_estimate_marker_kwargs, + true_values=true_values, + true_values_kwargs=true_values_kwargs, ) if backend is None: From f6bfba0c192329baf1407a0de01e26c669164038 Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo Date: Tue, 7 Apr 2020 16:40:55 -0300 Subject: [PATCH 3/4] change argument name to reference_values change argument name to reference_values --- CHANGELOG.md | 4 +- arviz/plots/backends/bokeh/pairplot.py | 34 +++++++------ arviz/plots/backends/matplotlib/pairplot.py | 49 +++++++++---------- arviz/plots/pairplot.py | 16 +++--- .../tests/base_tests/test_plots_matplotlib.py | 1 - 5 files changed, 51 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16db85675b..fbea5fcdff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,16 +3,14 @@ ## v0.x.x Unreleased ### New features -* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables +* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables #1140 * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 * Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090 * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * Allow xarray.Dataarray input for plots.(#1120) * Revamped the `hpd` function to make it work with mutidimensional arrays, InferenceData and xarray objects (#1117) * Skip test for optional/extra dependencies when not installed (#1113) * Add option to display rank plots instead of trace (#1134) -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 ### Maintenance and fixes * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index d180419a61..93a937e067 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -34,8 +34,8 @@ def plot_pair( marginal_kwargs, point_estimate, point_estimate_kwargs, - true_values, - true_values_kwargs, + reference_values, + reference_values_kwargs, show, ): """Bokeh pair plot.""" @@ -58,35 +58,37 @@ def plot_pair( kde_kwargs["contour_kwargs"].setdefault("line_color", "black") kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1) - if true_values: - true_values_copy = {} + if reference_values: + reference_values_copy = {} label = [] - for variable in list(true_values.keys()): + for variable in list(reference_values.keys()): if " " in variable: variable_copy = variable.replace(" ", "\n", 1) else: variable_copy = variable label.append(variable_copy) - true_values_copy[variable_copy] = true_values[variable] + reference_values_copy[variable_copy] = reference_values[variable] difference = set(flat_var_names).difference(set(label)) for dif in difference: - true_values_copy[dif] = None + reference_values_copy[dif] = None if difference: warn = [dif.replace("\n", " ", 1) for dif in difference] warnings.warn( - "Argument true_values does not include true value for: {}".format(", ".join(warn)), + "Argument reference_values does not include true value for: {}".format( + ", ".join(warn) + ), UserWarning, ) - if true_values_kwargs is None: - true_values_kwargs = {} + if reference_values_kwargs is None: + reference_values_kwargs = {} - true_values_kwargs.setdefault("line_color", "red") - true_values_kwargs.setdefault("line_width", 5) + reference_values_kwargs.setdefault("line_color", "red") + reference_values_kwargs.setdefault("line_width", 5) dpi = backend_kwargs.pop("dpi") max_plots = ( @@ -301,11 +303,11 @@ def get_width_and_height(jointplot, rotate): ) ax[-1, -1].add_layout(ax_pe_hline) - if true_values: - x = true_values_copy[flat_var_names[j + var]] - y = true_values_copy[flat_var_names[i]] + if reference_values: + x = reference_values_copy[flat_var_names[j + var]] + y = reference_values_copy[flat_var_names[i]] if x and y: - ax[j, i].circle(y, x, **true_values_kwargs) + ax[j, i].circle(y, x, **reference_values_kwargs) ax[j, i].xaxis.axis_label = flat_var_names[i] ax[j, i].yaxis.axis_label = flat_var_names[j + var] diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index e2eae5839c..c8ba2f1114 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -39,8 +39,8 @@ def plot_pair( point_estimate, point_estimate_kwargs, point_estimate_marker_kwargs, - true_values, - true_values_kwargs, + reference_values, + reference_values_kwargs, ): """Matplotlib pairplot.""" if backend_kwargs is None: @@ -62,35 +62,37 @@ def plot_pair( kde_kwargs.setdefault("contour_kwargs", {}) kde_kwargs["contour_kwargs"].setdefault("colors", "k") - if true_values: - true_values_copy = {} + if reference_values: + reference_values_copy = {} label = [] - for variable in list(true_values.keys()): + for variable in list(reference_values.keys()): if " " in variable: variable_copy = variable.replace(" ", "\n", 1) else: variable_copy = variable label.append(variable_copy) - true_values_copy[variable_copy] = true_values[variable] + reference_values_copy[variable_copy] = reference_values[variable] difference = set(flat_var_names).difference(set(label)) for dif in difference: - true_values_copy[dif] = None + reference_values_copy[dif] = None if difference: warn = [dif.replace("\n", " ", 1) for dif in difference] warnings.warn( - "Argument true_values does not include true value for: {}".format(", ".join(warn)), + "Argument reference_values does not include true value for: {}".format( + ", ".join(warn) + ), UserWarning, ) - if true_values_kwargs is None: - true_values_kwargs = {} + if reference_values_kwargs is None: + reference_values_kwargs = {} - true_values_kwargs.setdefault("color", "C3") - true_values_kwargs.setdefault("s", 30) + reference_values_kwargs.setdefault("color", "C3") + reference_values_kwargs.setdefault("marker", "o") # pylint: disable=too-many-nested-blocks if numvars == 2: @@ -120,9 +122,6 @@ def plot_pair( for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)): plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs) - ax_hist_x.set_xlim(ax.get_xlim()) - ax_hist_y.set_ylim(ax.get_ylim()) - # Personalize axes ax_hist_x.tick_params(labelleft=False, labelbottom=False) ax_hist_y.tick_params(labelleft=False, labelbottom=False) @@ -175,11 +174,11 @@ def plot_pair( ax.scatter(pe_x, pe_y, marker="s", s=figsize[0] + 50, **point_estimate_kwargs, zorder=4) - if true_values: - ax.scatter( - true_values_copy[flat_var_names[0]], - true_values_copy[flat_var_names[1]], - **true_values_kwargs, + if reference_values: + ax.plot( + reference_values_copy[flat_var_names[0]], + reference_values_copy[flat_var_names[1]], + **reference_values_kwargs, ) ax.set_xlabel("{}".format(flat_var_names[0]), fontsize=ax_labelsize, wrap=True) ax.set_ylabel("{}".format(flat_var_names[1]), fontsize=ax_labelsize, wrap=True) @@ -270,11 +269,11 @@ def plot_pair( pe_x, pe_y, s=figsize[0] + 50, zorder=4, **point_estimate_marker_kwargs ) - if true_values: - ax[j, i].scatter( - true_values_copy[flat_var_names[i]], - true_values_copy[flat_var_names[j]], - **true_values_kwargs, + if reference_values: + ax[j, i].plot( + reference_values_copy[flat_var_names[i]], + reference_values_copy[flat_var_names[j]], + **reference_values_kwargs, ) if j != numvars - 1: diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 730614e17f..b98ed226d6 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -34,8 +34,8 @@ def plot_pair( point_estimate=None, point_estimate_kwargs=None, point_estimate_marker_kwargs=None, - true_values=None, - true_values_kwargs=None, + reference_values=None, + reference_values_kwargs=None, show=None, ): """ @@ -101,11 +101,11 @@ def plot_pair( Additional keywords passed to ax.vline, ax.hline (matplotlib) or ax.square, Span (bokeh) point_estimate_marker_kwargs: dict, optional Additional keywords passed to ax.scatter in point estimate plot. Not available in bokeh - true_values : dict, optional - True values for the plotted variables. The true values will be plotted + reference_values : dict, optional + Reference values for the plotted variables. The Reference values will be plotted using a scatter marker - true_values_kwargs : dict, optional - Additional keywords passed to ax.scatter or ax.circle in true values plot + reference_values_kwargs : dict, optional + Additional keywords passed to ax.plot or ax.circle in reference values plot show : bool, optional Call backend show function. @@ -283,8 +283,8 @@ def plot_pair( point_estimate=point_estimate, point_estimate_kwargs=point_estimate_kwargs, point_estimate_marker_kwargs=point_estimate_marker_kwargs, - true_values=true_values, - true_values_kwargs=true_values_kwargs, + reference_values=reference_values, + reference_values_kwargs=reference_values_kwargs, ) if backend is None: diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index cafee80147..21a4612673 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -161,7 +161,6 @@ def test_plot_trace(models, kwargs): @pytest.mark.parametrize( "combined", [True, False], ) - def test_plot_trace_legend(compact, combined): idata = load_arviz_data("rugby") axes = plot_trace( From ee1b316b20899a0801aa0d48ad738f97e02dd30d Mon Sep 17 00:00:00 2001 From: agustinaarroyuelo Date: Thu, 9 Apr 2020 12:49:16 -0300 Subject: [PATCH 4/4] update tests --- arviz/plots/backends/bokeh/pairplot.py | 2 +- arviz/plots/backends/matplotlib/pairplot.py | 18 +++++++++--------- arviz/tests/base_tests/test_plots_bokeh.py | 5 +++++ .../tests/base_tests/test_plots_matplotlib.py | 5 +++++ 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index 93a937e067..e9ec149f52 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -78,7 +78,7 @@ def plot_pair( if difference: warn = [dif.replace("\n", " ", 1) for dif in difference] warnings.warn( - "Argument reference_values does not include true value for: {}".format( + "Argument reference_values does not include reference value for: {}".format( ", ".join(warn) ), UserWarning, diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index c8ba2f1114..75559b533a 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -76,13 +76,10 @@ def plot_pair( difference = set(flat_var_names).difference(set(label)) - for dif in difference: - reference_values_copy[dif] = None - if difference: warn = [dif.replace("\n", " ", 1) for dif in difference] warnings.warn( - "Argument reference_values does not include true value for: {}".format( + "Argument reference_values does not include reference value for: {}".format( ", ".join(warn) ), UserWarning, @@ -270,11 +267,14 @@ def plot_pair( ) if reference_values: - ax[j, i].plot( - reference_values_copy[flat_var_names[i]], - reference_values_copy[flat_var_names[j]], - **reference_values_kwargs, - ) + x_name = flat_var_names[i] + y_name = flat_var_names[j] + if x_name and y_name not in difference: + ax[j, i].plot( + reference_values_copy[x_name], + reference_values_copy[y_name], + **reference_values_kwargs, + ) if j != numvars - 1: ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter()) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 9103d5510a..8275898db0 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -710,6 +710,11 @@ def test_plot_mcse_no_divergences(models): "coords": {"theta_dim_0": [0, 1]}, "textsize": 20, }, + { + "point_estimate": "mean", + "reference_values": {"mu": 0, "tau": 0}, + "reference_values_kwargs": {"line_color": "blue"}, + }, ], ) def test_plot_pair(models, kwargs): diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 21a4612673..17b7d7da22 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -416,6 +416,11 @@ def test_plot_kde_inference_data(models): "hexbin_kwargs": {"cmap": "viridis"}, "textsize": 20, }, + { + "point_estimate": "mean", + "reference_values": {"mu": 0, "tau": 0}, + "reference_values_kwargs": {"c": "C", "marker": "*"}, + }, ], ) def test_plot_pair(models, kwargs):