diff --git a/CHANGELOG.md b/CHANGELOG.md index ebf6822400..fbea5fcdff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,9 +3,10 @@ ## v0.x.x Unreleased ### New features +* Add `reference_values` argument for `plot_pair`. It allows for a scatter plot showing the reference (e.g. true) values of the variables #1140 +* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079 * Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090 * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 -* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * Allow xarray.Dataarray input for plots.(#1120) * Revamped the `hpd` function to make it work with mutidimensional arrays, InferenceData and xarray objects (#1117) * Skip test for optional/extra dependencies when not installed (#1113) @@ -217,3 +218,4 @@ ## v0.3.0 (2018 Dec 14) * First Beta Release + diff --git a/arviz/plots/backends/bokeh/pairplot.py b/arviz/plots/backends/bokeh/pairplot.py index e6d3e8289f..e9ec149f52 100644 --- a/arviz/plots/backends/bokeh/pairplot.py +++ b/arviz/plots/backends/bokeh/pairplot.py @@ -34,6 +34,8 @@ def plot_pair( marginal_kwargs, point_estimate, point_estimate_kwargs, + reference_values, + reference_values_kwargs, show, ): """Bokeh pair plot.""" @@ -56,6 +58,38 @@ def plot_pair( kde_kwargs["contour_kwargs"].setdefault("line_color", "black") kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1) + if reference_values: + reference_values_copy = {} + label = [] + for variable in list(reference_values.keys()): + if " " in variable: + variable_copy = variable.replace(" ", "\n", 1) + else: + variable_copy = variable + + label.append(variable_copy) + reference_values_copy[variable_copy] = reference_values[variable] + + difference = set(flat_var_names).difference(set(label)) + + for dif in difference: + reference_values_copy[dif] = None + + if difference: + warn = [dif.replace("\n", " 
", 1) for dif in difference] + warnings.warn( + "Argument reference_values does not include reference value for: {}".format( + ", ".join(warn) + ), + UserWarning, + ) + + if reference_values_kwargs is None: + reference_values_kwargs = {} + + reference_values_kwargs.setdefault("line_color", "red") + reference_values_kwargs.setdefault("line_width", 5) + dpi = backend_kwargs.pop("dpi") max_plots = ( numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"] @@ -150,6 +184,7 @@ def get_width_and_height(jointplot, rotate): ax = np.array(ax) else: assert ax.shape == (numvars - var, numvars - var) + # pylint: disable=too-many-nested-blocks for i in range(0, numvars - var): @@ -268,6 +303,12 @@ def get_width_and_height(jointplot, rotate): ) ax[-1, -1].add_layout(ax_pe_hline) + if reference_values: + x = reference_values_copy[flat_var_names[j + var]] + y = reference_values_copy[flat_var_names[i]] + if x is not None and y is not None: + ax[j, i].circle(y, x, **reference_values_kwargs) + ax[j, i].xaxis.axis_label = flat_var_names[i] ax[j, i].yaxis.axis_label = flat_var_names[j + var] diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 12e1b7f2d4..75559b533a 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -39,6 +39,8 @@ def plot_pair( point_estimate, point_estimate_kwargs, point_estimate_marker_kwargs, + reference_values, + reference_values_kwargs, ): """Matplotlib pairplot.""" if backend_kwargs is None: @@ -60,6 +62,35 @@ def plot_pair( kde_kwargs.setdefault("contour_kwargs", {}) kde_kwargs["contour_kwargs"].setdefault("colors", "k") + if reference_values: + reference_values_copy = {} + label = [] + for variable in list(reference_values.keys()): + if " " in variable: + variable_copy = variable.replace(" ", "\n", 1) + else: + variable_copy = variable + + label.append(variable_copy) + reference_values_copy[variable_copy] = reference_values[variable] + + 
difference = set(flat_var_names).difference(set(label)) + + if difference: + warn = [dif.replace("\n", " ", 1) for dif in difference] + warnings.warn( + "Argument reference_values does not include reference value for: {}".format( + ", ".join(warn) + ), + UserWarning, + ) + + if reference_values_kwargs is None: + reference_values_kwargs = {} + + reference_values_kwargs.setdefault("color", "C3") + reference_values_kwargs.setdefault("marker", "o") + # pylint: disable=too-many-nested-blocks if numvars == 2: (figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size( @@ -140,6 +171,12 @@ def plot_pair( ax.scatter(pe_x, pe_y, marker="s", s=figsize[0] + 50, **point_estimate_kwargs, zorder=4) + if reference_values and not difference: + ax.plot( + reference_values_copy[flat_var_names[0]], + reference_values_copy[flat_var_names[1]], + **reference_values_kwargs, + ) ax.set_xlabel("{}".format(flat_var_names[0]), fontsize=ax_labelsize, wrap=True) ax.set_ylabel("{}".format(flat_var_names[1]), fontsize=ax_labelsize, wrap=True) ax.tick_params(labelsize=xt_labelsize) @@ -229,6 +266,16 @@ def plot_pair( pe_x, pe_y, s=figsize[0] + 50, zorder=4, **point_estimate_marker_kwargs ) + if reference_values: + x_name = flat_var_names[i] + y_name = flat_var_names[j] + if x_name not in difference and y_name not in difference: + ax[j, i].plot( + reference_values_copy[x_name], + reference_values_copy[y_name], + **reference_values_kwargs, + ) + if j != numvars - 1: ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter()) else: diff --git a/arviz/plots/pairplot.py b/arviz/plots/pairplot.py index 7adff78a47..b98ed226d6 100644 --- a/arviz/plots/pairplot.py +++ b/arviz/plots/pairplot.py @@ -17,7 +17,7 @@ def plot_pair( textsize=None, kind="scatter", gridsize="auto", - contour=False, + contour=True, plot_kwargs=None, fill_last=False, divergences=False, @@ -34,6 +34,8 @@ def plot_pair( point_estimate=None, point_estimate_kwargs=None, point_estimate_marker_kwargs=None, + reference_values=None, + 
reference_values_kwargs=None, show=None, ): """ @@ -99,6 +101,11 @@ def plot_pair( Additional keywords passed to ax.vline, ax.hline (matplotlib) or ax.square, Span (bokeh) point_estimate_marker_kwargs: dict, optional Additional keywords passed to ax.scatter in point estimate plot. Not available in bokeh + reference_values : dict, optional + Reference values for the plotted variables. The Reference values will be plotted + using a scatter marker + reference_values_kwargs : dict, optional + Additional keywords passed to ax.plot or ax.circle in reference values plot show : bool, optional Call backend show function. @@ -276,6 +283,8 @@ def plot_pair( point_estimate=point_estimate, point_estimate_kwargs=point_estimate_kwargs, point_estimate_marker_kwargs=point_estimate_marker_kwargs, + reference_values=reference_values, + reference_values_kwargs=reference_values_kwargs, ) if backend is None: diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 9103d5510a..8275898db0 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -710,6 +710,11 @@ def test_plot_mcse_no_divergences(models): "coords": {"theta_dim_0": [0, 1]}, "textsize": 20, }, + { + "point_estimate": "mean", + "reference_values": {"mu": 0, "tau": 0}, + "reference_values_kwargs": {"line_color": "blue"}, + }, ], ) def test_plot_pair(models, kwargs): diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 60facbe797..17b7d7da22 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -155,8 +155,12 @@ def test_plot_trace(models, kwargs): assert axes.shape -@pytest.mark.parametrize("compact", [True, False]) -@pytest.mark.parametrize("combined", [True, False]) +@pytest.mark.parametrize( + "compact", [True, False], +) +@pytest.mark.parametrize( + "combined", [True, False], +) def 
test_plot_trace_legend(compact, combined): idata = load_arviz_data("rugby") axes = plot_trace( @@ -412,6 +416,11 @@ def test_plot_kde_inference_data(models): "hexbin_kwargs": {"cmap": "viridis"}, "textsize": 20, }, + { + "point_estimate": "mean", + "reference_values": {"mu": 0, "tau": 0}, + "reference_values_kwargs": {"c": "C", "marker": "*"}, + }, ], ) def test_plot_pair(models, kwargs):