Skip to content

Commit

Permalink
Support multichain ppc in plot_ppc (#526)
Browse files Browse the repository at this point in the history
* support multichains and add jitter

* redefine jitter scale

* reshape

* fix stuff

* optional legend in kdeplot

* redefine jitter so limit works as no jitter

* fix dtype issue

* add tests for multichain, multidimensional and data_pairs

* add tests

* fix flatten bug for pp with data_pairs

* np.all

* flatten_pp defaults

* default jitter to None
  • Loading branch information
ahartikainen authored Jan 15, 2019
1 parent 836d9b9 commit dcc1636
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 34 deletions.
5 changes: 4 additions & 1 deletion arviz/plots/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def plot_kde(
rug_kwargs=None,
contour_kwargs=None,
ax=None,
legend=True,
):
"""1D or 2D KDE plot taking into account boundary conditions.
Expand Down Expand Up @@ -70,6 +71,8 @@ def plot_kde(
contour_kwargs : dict
Keywords passed to the contourplot. Ignored for 1D KDE
ax : matplotlib axes
legend : bool
Add legend to the figure. By default True.
Returns
-------
Expand Down Expand Up @@ -198,7 +201,7 @@ def plot_kde(
ax.plot(x, density, label=label, **plot_kwargs)
fill_func(fill_x, fill_y, **fill_kwargs)

if label:
if legend and label:
ax.legend()
else:
if contour_kwargs is None:
Expand Down
105 changes: 75 additions & 30 deletions arviz/plots/ppcplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Posterior predictive plot."""
from numbers import Integral
import numpy as np
from .kdeplot import plot_kde, _fast_kde
from .plot_utils import (
Expand All @@ -14,16 +15,19 @@
def plot_ppc(
data,
kind="density",
alpha=0.2,
alpha=None,
mean=True,
figsize=None,
textsize=None,
data_pairs=None,
var_names=None,
coords=None,
flatten=None,
flatten_pp=None,
num_pp_samples=None,
random_seed=None,
jitter=None,
legend=True,
):
"""
Plot for Posterior Predictive checks.
Expand All @@ -36,7 +40,8 @@ def plot_ppc(
kind : str
Type of plot to display (density, cumulative, or scatter). Defaults to density.
alpha : float
Opacity of posterior predictive density curves. Defaults to 0.2.
Opacity of posterior predictive density curves. Defaults to 0.2 for kind = density
and cumulative; for kind = scatter it defaults to 0.7
mean : bool
Whether or not to plot the mean posterior predictive distribution. Defaults to True
figsize : tuple
Expand All @@ -63,17 +68,28 @@ def plot_ppc(
that dimension. Defaults to including all coordinates for all
dimensions if None.
flatten : list
List of dimensions to flatten. Only flattens across the coordinates
List of dimensions to flatten in observed_data. Only flattens across the coordinates
specified in the coords argument. Defaults to flattening all of the dimensions.
flatten_pp : list
List of dimensions to flatten in posterior_predictive. Only flattens across the coordinates
specified in the coords argument. Defaults to flattening all of the dimensions.
Dimensions should match flatten excluding dimensions for data_pairs parameters.
If flatten is defined and flatten_pp is None, then `flatten_pp=flatten`.
num_pp_samples : int
The number of posterior predictive samples to plot.
It defaults to a maximum of 30 samples for `kind` = 'scatter'.
It defaults to a maximum of 5 samples for `kind` = 'scatter' and
will set jitter to 0.7 unless defined otherwise.
Otherwise it defaults to all provided samples.
random_seed : int
Random number generator seed passed to numpy.random.seed to allow
reproducibility of the plot. By default, no seed will be provided
and the plot will change each call if a random sample is specified
by `num_pp_samples`.
jitter : float
If kind is "scatter", jitter will add random uniform noise to the height
of the ppc samples and observed data. Defaults to 0 (no jitter).
legend : bool
Add legend to figure. By default True.
Returns
-------
Expand Down Expand Up @@ -133,6 +149,16 @@ def plot_ppc(
if data_pairs is None:
data_pairs = {}

if alpha is None:
if kind.lower() == "scatter":
alpha = 0.7
else:
alpha = 0.2

if jitter is None:
jitter = 0.0
assert jitter >= 0.0

observed = data.observed_data
posterior_predictive = data.posterior_predictive

Expand All @@ -141,6 +167,10 @@ def plot_ppc(
var_names = _var_names(var_names)
pp_var_names = [data_pairs.get(var, var) for var in var_names]

if flatten_pp is None and flatten is None:
flatten_pp = list(posterior_predictive.dims.keys())
elif flatten_pp is None:
flatten_pp = flatten
if flatten is None:
flatten = list(observed.dims.keys())

Expand All @@ -153,14 +183,14 @@ def plot_ppc(
total_pp_samples = posterior_predictive.sizes["chain"] * posterior_predictive.sizes["draw"]
if num_pp_samples is None:
if kind == "scatter":
num_pp_samples = min(30, total_pp_samples)
num_pp_samples = min(5, total_pp_samples)
else:
num_pp_samples = total_pp_samples

if (
not isinstance(num_pp_samples, int)
or not num_pp_samples >= 1
or not num_pp_samples <= total_pp_samples
not isinstance(num_pp_samples, Integral)
or num_pp_samples < 1
or num_pp_samples > total_pp_samples
):
raise TypeError(
"`num_pp_samples` must be an integer between 1 and "
Expand All @@ -181,7 +211,7 @@ def plot_ppc(
xarray_var_iter(
posterior_predictive.isel(coords),
var_names=pp_var_names,
skip_dims=set(flatten),
skip_dims=set(flatten_pp),
combined=True,
)
)
Expand All @@ -201,9 +231,7 @@ def plot_ppc(

# flatten non-specified dimensions
obs_vals = obs_vals.flatten()
pp_vals = pp_vals.squeeze()
if len(pp_vals.shape) > 2:
pp_vals = pp_vals.reshape((pp_vals.shape[0], np.prod(pp_vals.shape[1:])))
pp_vals = pp_vals.reshape(total_pp_samples, -1)
pp_sampled_vals = pp_vals[pp_sample_ix]

if kind == "density":
Expand All @@ -214,6 +242,7 @@ def plot_ppc(
plot_kwargs={"color": "k", "linewidth": linewidth, "zorder": 3},
fill_kwargs={"alpha": 0},
ax=ax,
legend=legend,
)
else:
nbins = round(len(obs_vals) ** 0.5)
Expand Down Expand Up @@ -258,6 +287,7 @@ def plot_ppc(
},
label="Posterior predictive mean {}".format(pp_var_name),
ax=ax,
legend=legend,
)
else:
vals = pp_vals.flatten()
Expand Down Expand Up @@ -333,15 +363,6 @@ def plot_ppc(
ax.set_yticks([0, 0.5, 1])

elif kind == "scatter":
ax.plot(
obs_vals,
np.zeros_like(obs_vals),
"o",
color="C0",
markersize=markersize,
label="Observed {}".format(var_name),
)

if mean:
if dtype == "f":
plot_kde(
Expand All @@ -350,10 +371,11 @@ def plot_ppc(
"color": "C0",
"linestyle": "--",
"linewidth": linewidth,
"zorder": 2,
"zorder": 3,
},
label="Posterior predictive mean {}".format(pp_var_name),
ax=ax,
legend=legend,
)
else:
vals = pp_vals.flatten()
Expand All @@ -366,16 +388,38 @@ def plot_ppc(
color="C0",
linewidth=linewidth,
label="Posterior predictive mean {}".format(pp_var_name),
zorder=2,
zorder=3,
linestyle="--",
drawstyle="steps-pre",
)

limit = ax.get_ylim()[1] * 1.05
y_rows = np.linspace(0, limit, num_pp_samples)
_, limit = ax.get_ylim()
limit *= 1.05
y_rows = np.linspace(0, limit, num_pp_samples + 1)
jitter_scale = y_rows[1] - y_rows[0]
scale_low = 0
scale_high = jitter_scale * jitter

obs_yvals = np.zeros_like(obs_vals, dtype=np.float64)
if jitter:
obs_yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(obs_vals))
ax.plot(
obs_vals,
obs_yvals,
"o",
color="C0",
markersize=markersize,
alpha=alpha,
label="Observed {}".format(var_name),
zorder=4,
)

for vals, y in zip(pp_sampled_vals, y_rows[1:]):
vals = np.array([vals]).flatten()
ax.plot(vals, [y] * len(vals), "o", zorder=1, color="C5", markersize=markersize)
yvals = np.full_like(vals, y, dtype=np.float64)
if jitter:
yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
ax.plot(vals, yvals, "o", zorder=2, color="C5", markersize=markersize, alpha=alpha)
ax.scatter([], [], color="C5", label="Posterior predictive {}".format(pp_var_name))

ax.set_yticks([])
Expand All @@ -386,10 +430,11 @@ def plot_ppc(
xlabel = var_name
ax.set_xlabel(make_label(xlabel, selection), fontsize=ax_labelsize)

if i == 0:
ax.legend(fontsize=xt_labelsize)
else:
ax.legend([])
if legend:
if i == 0:
ax.legend(fontsize=xt_labelsize)
else:
ax.legend([])

return axes

Expand Down
22 changes: 19 additions & 3 deletions arviz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pymc3 as pm


from ..data import from_pymc3, InferenceData
from ..data import from_dict, from_pymc3, InferenceData
from ..stats import compare
from .helpers import eight_schools_params, load_cached_models # pylint: disable=unused-import
from ..plots import (
Expand Down Expand Up @@ -365,12 +365,28 @@ def test_plot_pair_bad(models, model_fit):


@pytest.mark.parametrize("kind", ["density", "cumulative", "scatter"])
def test_plot_ppc(models, pymc3_sample_ppc, kind):
@pytest.mark.parametrize("alpha", [None, 0.2, 1])
def test_plot_ppc(models, pymc3_sample_ppc, kind, alpha):
data = from_pymc3(trace=models.pymc3_fit, posterior_predictive=pymc3_sample_ppc)
axes = plot_ppc(data, kind=kind, random_seed=3)
axes = plot_ppc(data, kind=kind, alpha=alpha, random_seed=3)
assert axes


@pytest.mark.parametrize("kind", ["density", "cumulative", "scatter"])
@pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
def test_plot_ppc_multichain(kind, jitter):
    """Smoke-test plot_ppc on multichain, multidimensional data with data_pairs.

    Exercises the commit's new features together: 4-chain posterior
    predictive samples, a 2D variable ("y"), the data_pairs mapping
    (observed "y" -> posterior predictive "y_hat"), and every supported
    jitter value including None (treated as no jitter) and 0.
    """
    # Fixed seed so the randn draws below are reproducible; the draw ORDER
    # matters, so the statements must not be reordered.
    np.random.seed(23)
    data = from_dict(
        posterior_predictive={
            # (chain, draw, *shape): 4 chains x 100 draws each.
            "x": np.random.randn(4, 100, 30),
            "y_hat": np.random.randn(4, 100, 3, 10),
        },
        # Observed data has no chain/draw dims; "y" pairs with "y_hat" above.
        observed_data={"x": np.random.randn(30), "y": np.random.randn(3, 10)},
    )
    # random_seed pins the subsample of pp draws so the plot is deterministic.
    axes = plot_ppc(data, kind=kind, data_pairs={"y": "y_hat"}, jitter=jitter, random_seed=3)
    # plot_ppc returns an array of axes (one per variable); all must be truthy.
    assert np.all(axes)


@pytest.mark.parametrize("kind", ["density", "cumulative", "scatter"])
def test_plot_ppc_discrete(kind):
data = MagicMock(spec=InferenceData)
Expand Down

0 comments on commit dcc1636

Please sign in to comment.