Skip to content

Commit

Permalink
add observed_rug argument to plot_ppc (#2161)
Browse files Browse the repository at this point in the history
* add observed_rug argument to plot_ppc

* add tests and update changelog
  • Loading branch information
aloctavodia authored Nov 14, 2022
1 parent 24e66c3 commit 6e8c8ab
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
## v0.x.x Unreleased

### New features
* Add `weight_predictions` function to allow generation of weighted predictions from two or more InfereceData with `posterior_predictive` groups and a set of weights ([2147](https://github.com/arviz-devs/arviz/pull/2147))
- Add `weight_predictions` function to allow generation of weighted predictions from two or more InfereceData with `posterior_predictive` groups and a set of weights ([2147](https://github.com/arviz-devs/arviz/pull/2147))
- Add Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152
- Adds rug plot for observed variables to `plot_ppc`. ([2161](https://github.com/arviz-devs/arviz/pull/2161))

### Maintenance and fixes
- Fix dimension ordering for `plot_trace` with divergences ([2151](https://github.com/arviz-devs/arviz/pull/2151))
Expand Down
28 changes: 28 additions & 0 deletions arviz/plots/backends/bokeh/ppcplot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Bokeh Posterior predictive plot."""
import numpy as np
from bokeh.models.annotations import Legend
from bokeh.models.glyphs import Scatter
from bokeh.models import ColumnDataSource


from ....stats.density_utils import get_bins, histogram, kde
from ...kdeplot import plot_kde
from ...plot_utils import _scale_fig_size, vectorized_to_hex


from .. import show_layout
from . import backend_kwarg_defaults, create_axes_grid

Expand All @@ -27,6 +31,7 @@ def plot_ppc(
textsize,
mean,
observed,
observed_rug,
jitter,
total_pp_samples,
legend, # pylint: disable=unused-argument
Expand Down Expand Up @@ -97,6 +102,7 @@ def plot_ppc(
obs_vals = obs_vals.flatten()
pp_vals = pp_vals.reshape(total_pp_samples, -1)
pp_sampled_vals = pp_vals[pp_sample_ix]
cds_rug = ColumnDataSource({"_": np.array(obs_vals)})

if kind == "kde":
plot_kwargs = {
Expand Down Expand Up @@ -144,6 +150,16 @@ def plot_ppc(
return_glyph=True,
)
legend_it.append((label, glyph))
if observed_rug:
glyph = Scatter(
x="_",
y=0.0,
marker="dash",
angle=np.pi / 2,
line_color=colors[1],
line_width=linewidth,
)
ax_i.add_glyph(cds_rug, glyph)
else:
bins = get_bins(obs_vals)
_, hist, bin_edges = histogram(obs_vals, bins=bins)
Expand Down Expand Up @@ -215,6 +231,18 @@ def plot_ppc(
mode="center",
)
legend_it.append((label, [step]))

if observed_rug:
glyph = Scatter(
x="_",
y=0.0,
marker="dash",
angle=np.pi / 2,
line_color=colors[1],
line_width=linewidth,
)
ax_i.add_glyph(cds_rug, glyph)

pp_densities = np.empty((2 * len(pp_sampled_vals), pp_sampled_vals[0].size))
for idx, vals in enumerate(pp_sampled_vals):
vals = np.array([vals]).flatten()
Expand Down
10 changes: 10 additions & 0 deletions arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def plot_ppc(
textsize,
mean,
observed,
observed_rug,
jitter,
total_pp_samples,
legend,
Expand Down Expand Up @@ -135,6 +136,7 @@ def plot_ppc(
if dtype == "f":
plot_kde(
obs_vals,
rug=observed_rug,
label="Observed",
plot_kwargs={"color": colors[1], "linewidth": linewidth, "zorder": 3},
fill_kwargs={"alpha": 0},
Expand Down Expand Up @@ -232,6 +234,14 @@ def plot_ppc(
drawstyle=drawstyle,
zorder=3,
)
if observed_rug:
ax_i.plot(
obs_vals,
np.zeros_like(obs_vals) - 0.1,
ls="",
marker="|",
color=colors[1],
)
if animated:
animate, init = _set_animation(
pp_sampled_vals,
Expand Down
5 changes: 5 additions & 0 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_ppc(
alpha=None,
mean=True,
observed=True,
observed_rug=False,
color=None,
colors=None,
grid=None,
Expand Down Expand Up @@ -62,6 +63,9 @@ def plot_ppc(
Defaults to ``True``.
observed: bool, default True
Whether or not to plot the observed data.
observed: bool, default False
Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
`True` and for kind `kde` or `cumulative`.
color: str
Valid matplotlib ``color``. Defaults to ``C0``.
color: list
Expand Down Expand Up @@ -339,6 +343,7 @@ def plot_ppc(
textsize=textsize,
mean=mean,
observed=observed,
observed_rug=observed_rug,
total_pp_samples=total_pp_samples,
legend=legend,
labeller=labeller,
Expand Down
4 changes: 3 additions & 1 deletion arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,12 +878,14 @@ def test_plot_violin_discrete(discrete_model):
@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
@pytest.mark.parametrize("observed", [True, False])
def test_plot_ppc(models, kind, alpha, observed):
@pytest.mark.parametrize("observed_rug", [False, True])
def test_plot_ppc(models, kind, alpha, observed, observed_rug):
axes = plot_ppc(
models.model_1,
kind=kind,
alpha=alpha,
observed=observed,
observed_rug=observed_rug,
random_seed=3,
backend="bokeh",
show=False,
Expand Down
4 changes: 3 additions & 1 deletion arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,8 @@ def test_plot_pair_shared(sharex, sharey, marginals):
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
@pytest.mark.parametrize("animated", [False, True])
@pytest.mark.parametrize("observed", [True, False])
def test_plot_ppc(models, kind, alpha, animated, observed):
@pytest.mark.parametrize("observed_rug", [False, True])
def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug):
if animation and not animation.writers.is_available("ffmpeg"):
pytest.skip("matplotlib animations within ArviZ require ffmpeg")
animation_kwargs = {"blit": False}
Expand All @@ -720,6 +721,7 @@ def test_plot_ppc(models, kind, alpha, animated, observed):
kind=kind,
alpha=alpha,
observed=observed,
observed_rug=observed_rug,
animated=animated,
animation_kwargs=animation_kwargs,
random_seed=3,
Expand Down

0 comments on commit 6e8c8ab

Please sign in to comment.