Skip to content

Commit

Permalink
Integrate stats.ic_scale rcParam (#993)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shashankjain12 authored and OriolAbril committed Jan 20, 2020
1 parent 4f2afc9 commit 1f4878f
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 52 deletions.
3 changes: 2 additions & 1 deletion arviz/plots/elpdplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def plot_elpd(
threshold=None,
ax=None,
ic=None,
scale="deviance",
scale=None,
plot_kwargs=None,
backend=None,
backend_kwargs=None,
Expand Down Expand Up @@ -109,6 +109,7 @@ def plot_elpd(
"""
valid_ics = ["waic", "loo"]
ic = rcParams["stats.information_criterion"] if ic is None else ic.lower()
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
if ic not in valid_ics:
raise ValueError(
("Information Criteria type {} not recognized." "IC must be in {}").format(
Expand Down
28 changes: 12 additions & 16 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,7 @@


def compare(
dataset_dict,
ic=None,
method="BB-pseudo-BMA",
b_samples=1000,
alpha=1,
seed=None,
scale="deviance",
dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None
):
r"""Compare models based on WAIC or LOO cross-validation.
Expand Down Expand Up @@ -136,7 +130,7 @@ def compare(
"""
names = list(dataset_dict.keys())
scale = scale.lower()
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
if scale == "log":
scale_value = 1
ascending = False
Expand Down Expand Up @@ -421,7 +415,7 @@ def hpd(ary, credible_interval=0.94, circular=False, multimodal=False):
return hpd_intervals


def loo(data, pointwise=False, reff=None, scale="deviance"):
def loo(data, pointwise=False, reff=None, scale=None):
"""Pareto-smoothed importance sampling leave-one-out cross-validation.
Calculates leave-one-out (LOO) cross-validation for out of sample predictive model fit,
Expand Down Expand Up @@ -493,12 +487,13 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
shape = log_likelihood.shape
n_samples = shape[-1]
n_data_points = np.product(shape[:-1])
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()

if scale.lower() == "deviance":
if scale == "deviance":
scale_value = -2
elif scale.lower() == "log":
elif scale == "log":
scale_value = 1
elif scale.lower() == "negative_log":
elif scale == "negative_log":
scale_value = -1
else:
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
Expand Down Expand Up @@ -1101,7 +1096,7 @@ def summary(
return summary_df


def waic(data, pointwise=False, scale="deviance"):
def waic(data, pointwise=False, scale=None):
"""Calculate the widely available information criterion.
Also calculates the WAIC's standard error and the effective number of
Expand Down Expand Up @@ -1165,12 +1160,13 @@ def waic(data, pointwise=False, scale="deviance"):
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
log_likelihood = inference_data.sample_stats.log_likelihood
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()

if scale.lower() == "deviance":
if scale == "deviance":
scale_value = -2
elif scale.lower() == "log":
elif scale == "log":
scale_value = 1
elif scale.lower() == "negative_log":
elif scale == "negative_log":
scale_value = -1
else:
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
Expand Down
44 changes: 14 additions & 30 deletions arviz/tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_plot_kde_1d(continuous_model):
"kwargs",
[
{"contour": True, "fill_last": False},
{"contour": True, "contourf_kwargs": {"cmap": "plasma"},},
{"contour": True, "contourf_kwargs": {"cmap": "plasma"}},
{"contour": False},
{"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}},
],
Expand Down Expand Up @@ -275,9 +275,7 @@ def test_plot_compare_no_ic(models):
assert "['waic', 'loo']" in str(err.value)


@pytest.mark.parametrize(
"kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"},],
)
@pytest.mark.parametrize("kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"}])
@pytest.mark.parametrize("add_model", [False, True])
@pytest.mark.parametrize("use_elpddata", [False, True])
def test_plot_elpd(models, add_model, use_elpddata, kwargs):
Expand All @@ -300,9 +298,7 @@ def test_plot_elpd(models, add_model, use_elpddata, kwargs):
assert axes.shape[0] == len(model_dict) - 1


@pytest.mark.parametrize(
"kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"},],
)
@pytest.mark.parametrize("kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"}])
@pytest.mark.parametrize("add_model", [False, True])
@pytest.mark.parametrize("use_elpddata", [False, True])
def test_plot_elpd_multidim(multidim_models, add_model, use_elpddata, kwargs):
Expand Down Expand Up @@ -443,9 +439,7 @@ def test_plot_forest(models, model_fits, args_expected):

def test_plot_forest_rope_exception():
with pytest.raises(ValueError) as err:
plot_forest(
{"x": [1]}, rope="not_correct_format", backend="bokeh", show=False,
)
plot_forest({"x": [1]}, rope="not_correct_format", backend="bokeh", show=False)
assert "Argument `rope` must be None, a dictionary like" in str(err.value)


Expand Down Expand Up @@ -508,19 +502,15 @@ def test_plot_joint_discrete(discrete_model):
def test_plot_joint_bad(models):
with pytest.raises(ValueError):
plot_joint(
models.model_1, var_names=("mu", "tau"), kind="bad_kind", backend="bokeh", show=False,
models.model_1, var_names=("mu", "tau"), kind="bad_kind", backend="bokeh", show=False
)

with pytest.raises(Exception):
plot_joint(
models.model_1, var_names=("mu", "tau", "eta"), backend="bokeh", show=False,
)
plot_joint(models.model_1, var_names=("mu", "tau", "eta"), backend="bokeh", show=False)

with pytest.raises(ValueError):
_, axes = list(range(5))
plot_joint(
models.model_1, var_names=("mu", "tau"), ax=axes, backend="bokeh", show=False,
)
plot_joint(models.model_1, var_names=("mu", "tau"), ax=axes, backend="bokeh", show=False)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -614,7 +604,7 @@ def test_plot_loo_pit_incompatible_args(models):
"""Test error when both ecdf and use_hpd are True."""
with pytest.raises(ValueError, match="incompatible"):
plot_loo_pit(
idata=models.model_1, y="y", ecdf=True, use_hpd=True, backend="bokeh", show=False,
idata=models.model_1, y="y", ecdf=True, use_hpd=True, backend="bokeh", show=False
)


Expand Down Expand Up @@ -694,8 +684,8 @@ def test_plot_mcse_no_divergences(models):
@pytest.mark.parametrize(
"kwargs",
[
{"var_names": "theta", "divergences": True, "coords": {"theta_dim_0": [0, 1]},},
{"divergences": True, "var_names": ["theta", "mu"],},
{"var_names": "theta", "divergences": True, "coords": {"theta_dim_0": [0, 1]}},
{"divergences": True, "var_names": ["theta", "mu"]},
{"kind": "kde", "var_names": ["theta"]},
{"kind": "hexbin", "var_names": ["theta"]},
{"kind": "hexbin", "var_names": ["theta"]},
Expand Down Expand Up @@ -760,7 +750,7 @@ def test_plot_parallel_exception(models, var_names):
"""Ensure that correct exception is raised when one variable is passed."""
with pytest.raises(ValueError):
assert plot_parallel(
models.model_1, var_names=var_names, norm_method="foo", backend="bokeh", show=False,
models.model_1, var_names=var_names, norm_method="foo", backend="bokeh", show=False
)


Expand Down Expand Up @@ -860,9 +850,7 @@ def test_plot_ppc_bad(models, kind):
with pytest.raises(TypeError):
plot_ppc(models.model_1, kind="bad_val", backend="bokeh", show=False)
with pytest.raises(TypeError):
plot_ppc(
models.model_1, num_pp_samples="bad_val", backend="bokeh", show=False,
)
plot_ppc(models.model_1, num_pp_samples="bad_val", backend="bokeh", show=False)


@pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
Expand Down Expand Up @@ -913,13 +901,9 @@ def test_plot_posterior_bad(models):
with pytest.raises(ValueError):
plot_posterior(models.model_1, backend="bokeh", show=False, rope="bad_value")
with pytest.raises(ValueError):
plot_posterior(
models.model_1, ref_val="bad_value", backend="bokeh", show=False,
)
plot_posterior(models.model_1, ref_val="bad_value", backend="bokeh", show=False)
with pytest.raises(ValueError):
plot_posterior(
models.model_1, point_estimate="bad_value", backend="bokeh", show=False,
)
plot_posterior(models.model_1, point_estimate="bad_value", backend="bokeh", show=False)


@pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
Expand Down
4 changes: 1 addition & 3 deletions examples/bokeh/bokeh_plot_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
ax_poisson = bkp.figure(**figure_kwargs)
ax_normal = bkp.figure(**figure_kwargs)

az.plot_dist(
a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", show=False,
)
az.plot_dist(a, color="black", label="Poisson", ax=ax_poisson, backend="bokeh", show=False)
az.plot_dist(b, color="red", label="Gaussian", ax=ax_normal, backend="bokeh", show=False)

ax = row(ax_poisson, ax_normal)
Expand Down
2 changes: 1 addition & 1 deletion examples/bokeh/bokeh_plot_loo_pit_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
log_weights = az.psislw(-log_like)[0]

ax = az.plot_loo_pit(
idata, y="y_like", log_weights=log_weights, ecdf=True, color="orange", backend="bokeh",
idata, y="y_like", log_weights=log_weights, ecdf=True, color="orange", backend="bokeh"
)
2 changes: 1 addition & 1 deletion examples/bokeh/bokeh_plot_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@

coords = {"school": ["Choate", "Deerfield"]}
ax = az.plot_pair(
centered, var_names=["theta", "mu", "tau"], coords=coords, divergences=True, backend="bokeh",
centered, var_names=["theta", "mu", "tau"], coords=coords, divergences=True, backend="bokeh"
)

0 comments on commit 1f4878f

Please sign in to comment.