Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New argument for plotting true values in pairplots #1140

Merged
merged 4 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
## v0.x.x Unreleased

### New features
* Add `reference_values` and `reference_values_kwargs` arguments for `plot_pair`. They allow overlaying a scatter marker showing the reference (e.g. true) values of the variables #1140
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079
* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow xarray.DataArray input for plots (#1120)
* Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117)
* Skip test for optional/extra dependencies when not installed (#1113)
Expand Down Expand Up @@ -217,3 +218,4 @@
## v0.3.0 (2018 Dec 14)

* First Beta Release

41 changes: 41 additions & 0 deletions arviz/plots/backends/bokeh/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def plot_pair(
marginal_kwargs,
point_estimate,
point_estimate_kwargs,
reference_values,
reference_values_kwargs,
show,
):
"""Bokeh pair plot."""
Expand All @@ -56,6 +58,38 @@ def plot_pair(
kde_kwargs["contour_kwargs"].setdefault("line_color", "black")
kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)

if reference_values:
reference_values_copy = {}
label = []
for variable in list(reference_values.keys()):
if " " in variable:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this step needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! I was trying to figure out a way so that the user can specify variable names in an easy manner. For example, simply writing "theta Choate" instead of "theta\nChoate"

variable_copy = variable.replace(" ", "\n", 1)
else:
variable_copy = variable

label.append(variable_copy)
reference_values_copy[variable_copy] = reference_values[variable]

difference = set(flat_var_names).difference(set(label))

for dif in difference:
reference_values_copy[dif] = None

if difference:
warn = [dif.replace("\n", " ", 1) for dif in difference]
warnings.warn(
"Argument reference_values does not include reference value for: {}".format(
", ".join(warn)
),
UserWarning,
)

if reference_values_kwargs is None:
reference_values_kwargs = {}

reference_values_kwargs.setdefault("line_color", "red")
reference_values_kwargs.setdefault("line_width", 5)

dpi = backend_kwargs.pop("dpi")
max_plots = (
numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
Expand Down Expand Up @@ -150,6 +184,7 @@ def get_width_and_height(jointplot, rotate):
ax = np.array(ax)
else:
assert ax.shape == (numvars - var, numvars - var)

# pylint: disable=too-many-nested-blocks
for i in range(0, numvars - var):

Expand Down Expand Up @@ -268,6 +303,12 @@ def get_width_and_height(jointplot, rotate):
)
ax[-1, -1].add_layout(ax_pe_hline)

if reference_values:
x = reference_values_copy[flat_var_names[j + var]]
y = reference_values_copy[flat_var_names[i]]
if x and y:
ax[j, i].circle(y, x, **reference_values_kwargs)

ax[j, i].xaxis.axis_label = flat_var_names[i]
ax[j, i].yaxis.axis_label = flat_var_names[j + var]

Expand Down
47 changes: 47 additions & 0 deletions arviz/plots/backends/matplotlib/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def plot_pair(
point_estimate,
point_estimate_kwargs,
point_estimate_marker_kwargs,
reference_values,
reference_values_kwargs,
):
"""Matplotlib pairplot."""
if backend_kwargs is None:
Expand All @@ -60,6 +62,35 @@ def plot_pair(
kde_kwargs.setdefault("contour_kwargs", {})
kde_kwargs["contour_kwargs"].setdefault("colors", "k")

if reference_values:
reference_values_copy = {}
label = []
for variable in list(reference_values.keys()):
if " " in variable:
variable_copy = variable.replace(" ", "\n", 1)
else:
variable_copy = variable

label.append(variable_copy)
reference_values_copy[variable_copy] = reference_values[variable]

difference = set(flat_var_names).difference(set(label))

if difference:
warn = [dif.replace("\n", " ", 1) for dif in difference]
warnings.warn(
"Argument reference_values does not include reference value for: {}".format(
", ".join(warn)
),
UserWarning,
)

if reference_values_kwargs is None:
reference_values_kwargs = {}

reference_values_kwargs.setdefault("color", "C3")
reference_values_kwargs.setdefault("marker", "o")

# pylint: disable=too-many-nested-blocks
if numvars == 2:
(figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size(
Expand Down Expand Up @@ -140,6 +171,12 @@ def plot_pair(

ax.scatter(pe_x, pe_y, marker="s", s=figsize[0] + 50, **point_estimate_kwargs, zorder=4)

if reference_values:
ax.plot(
reference_values_copy[flat_var_names[0]],
reference_values_copy[flat_var_names[1]],
**reference_values_kwargs,
)
ax.set_xlabel("{}".format(flat_var_names[0]), fontsize=ax_labelsize, wrap=True)
ax.set_ylabel("{}".format(flat_var_names[1]), fontsize=ax_labelsize, wrap=True)
ax.tick_params(labelsize=xt_labelsize)
Expand Down Expand Up @@ -229,6 +266,16 @@ def plot_pair(
pe_x, pe_y, s=figsize[0] + 50, zorder=4, **point_estimate_marker_kwargs
)

if reference_values:
x_name = flat_var_names[i]
y_name = flat_var_names[j]
if x_name and y_name not in difference:
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
**reference_values_kwargs,
)

if j != numvars - 1:
ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter())
else:
Expand Down
11 changes: 10 additions & 1 deletion arviz/plots/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def plot_pair(
textsize=None,
kind="scatter",
gridsize="auto",
contour=False,
contour=True,
plot_kwargs=None,
fill_last=False,
divergences=False,
Expand All @@ -34,6 +34,8 @@ def plot_pair(
point_estimate=None,
point_estimate_kwargs=None,
point_estimate_marker_kwargs=None,
reference_values=None,
reference_values_kwargs=None,
show=None,
):
"""
Expand Down Expand Up @@ -99,6 +101,11 @@ def plot_pair(
Additional keywords passed to ax.vline, ax.hline (matplotlib) or ax.square, Span (bokeh)
point_estimate_marker_kwargs: dict, optional
Additional keywords passed to ax.scatter in point estimate plot. Not available in bokeh
reference_values : dict, optional
Reference values for the plotted variables. The reference values will be plotted
using a scatter marker.
reference_values_kwargs : dict, optional
Additional keywords passed to ax.plot (matplotlib) or ax.circle (bokeh) in reference values plot
show : bool, optional
Call backend show function.

Expand Down Expand Up @@ -276,6 +283,8 @@ def plot_pair(
point_estimate=point_estimate,
point_estimate_kwargs=point_estimate_kwargs,
point_estimate_marker_kwargs=point_estimate_marker_kwargs,
reference_values=reference_values,
reference_values_kwargs=reference_values_kwargs,
)

if backend is None:
Expand Down
5 changes: 5 additions & 0 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,11 @@ def test_plot_mcse_no_divergences(models):
"coords": {"theta_dim_0": [0, 1]},
"textsize": 20,
},
{
"point_estimate": "mean",
"reference_values": {"mu": 0, "tau": 0},
"reference_values_kwargs": {"line_color": "blue"},
},
],
)
def test_plot_pair(models, kwargs):
Expand Down
13 changes: 11 additions & 2 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,12 @@ def test_plot_trace(models, kwargs):
assert axes.shape


@pytest.mark.parametrize("compact", [True, False])
@pytest.mark.parametrize("combined", [True, False])
@pytest.mark.parametrize(
"compact", [True, False],
)
@pytest.mark.parametrize(
"combined", [True, False],
)
def test_plot_trace_legend(compact, combined):
idata = load_arviz_data("rugby")
axes = plot_trace(
Expand Down Expand Up @@ -412,6 +416,11 @@ def test_plot_kde_inference_data(models):
"hexbin_kwargs": {"cmap": "viridis"},
"textsize": 20,
},
{
"point_estimate": "mean",
"reference_values": {"mu": 0, "tau": 0},
"reference_values_kwargs": {"c": "C", "marker": "*"},
},
],
)
def test_plot_pair(models, kwargs):
Expand Down