Modify compare() docstring, error-check, pass var_name. #1616

Merged
merged 19 commits on Mar 30, 2021
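For orientation before the file-by-file diff, here is a minimal usage sketch (not part of the PR; the model names are illustrative, and `var_name="obs"` matches the observed variable in ArviZ's example datasets) showing the `var_name` argument this PR threads through `compare()`:

```python
import arviz as az

# Example datasets shipped with ArviZ; any dict of InferenceData objects works.
models = {
    "centered": az.load_arviz_data("centered_eight"),
    "non_centered": az.load_arviz_data("non_centered_eight"),
}

# var_name picks the observed variable whose log likelihood is compared;
# it only matters when a model stores several log-likelihood arrays.
comparison = az.compare(models, ic="loo", var_name="obs")
print(comparison)
```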
3 changes: 2 additions & 1 deletion .pylintrc
@@ -244,6 +244,7 @@ function-naming-style=snake_case

# Good variable names which should always be accepted, separated by a comma
good-names=b,
e,
i,
j,
k,
@@ -265,7 +266,7 @@ good-names=b,
ok,
sd,
tr,
eta,
eta,
Run,
_log,
_
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@
* Added interactive legend to bokeh `forestplot` ([1591](https://github.com/arviz-devs/arviz/pull/1591))
* Added interactive legend to bokeh `ppcplot` ([1602](https://github.com/arviz-devs/arviz/pull/1602))
* Added `data.log_likelihood`, `stats.ic_compare_method` and `plot.density_kind` to `rcParams` ([1611](https://github.com/arviz-devs/arviz/pull/1611))
* Improved error messages in `stats.compare()` and added a `var_name` parameter ([1616](https://github.com/arviz-devs/arviz/pull/1616))

### Maintenance and fixes
* Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201))
20 changes: 13 additions & 7 deletions arviz/data/io_pymc3.py
@@ -89,7 +89,8 @@ def __init__(
# this permits us to get the model from command-line argument or from with model:
try:
self.model = self.pymc3.modelcontext(model or self.model)
except TypeError:
except TypeError as e:
_log.error("Got error %s trying to find log_likelihood in translation.", e)
self.model = None

if self.model is None:
@@ -251,12 +252,17 @@ def _extract_log_likelihood(self, trace):
"`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
) from err
for var, log_like_fun in cached:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
try:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
except TypeError as e:
raise TypeError(
*tuple(["While computing log-likelihood for {var}: "] + list(e.args))
) from e
return log_likelihood_dict.trace_dict

@requires("trace")
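As an aside, a self-contained sketch of the re-raise idiom used in `_extract_log_likelihood` above; `compute` and `var_name` are stand-ins, not ArviZ code. The extra first argument prepends context while `from e` keeps the original traceback chained:

```python
var_name = "obs"

def compute(point):
    # Deliberately provokes a TypeError, standing in for a failing
    # log-likelihood evaluation.
    return point + None

try:
    compute(1)
except TypeError as e:
    # Prepend a context message to the original exception's args.
    raise TypeError(f"While computing log-likelihood for {var_name}: ", *e.args) from e
```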
23 changes: 21 additions & 2 deletions arviz/rcparams.py
@@ -9,11 +9,22 @@
from collections.abc import MutableMapping
from pathlib import Path
from typing import Any, Dict
from typing_extensions import Literal

NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True


import numpy as np

_log = logging.getLogger(__name__)

ScaleKeyword = Literal["log", "negative_log", "deviance"]
ICKeyword = Literal["loo", "waic"]


def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
"""Validate value is in accepted_values.
@@ -277,9 +288,17 @@ def validate_iterable(value):
"plot.matplotlib.constrained_layout": (True, _validate_boolean),
"plot.matplotlib.show": (False, _validate_boolean),
"stats.hdi_prob": (0.94, _validate_probability),
"stats.information_criterion": ("loo", _make_validate_choice({"waic", "loo"})),
"stats.information_criterion": (
"loo",
_make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
),
"stats.ic_pointwise": (False, _validate_boolean),
"stats.ic_scale": ("log", _make_validate_choice({"deviance", "log", "negative_log"})),
"stats.ic_scale": (
"log",
_make_validate_choice(
{"log", "negative_log", "deviance"} if NO_GET_ARGS else set(get_args(ScaleKeyword))
),
),
"stats.ic_compare_method": (
"stacking",
_make_validate_choice({"stacking", "bb-pseudo-bma", "pseudo-bma"}),
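For reference, a small illustration of the pattern these rcParams changes rely on, assuming a `typing_extensions` recent enough to provide `get_args`: the accepted values are recovered from the `Literal` aliases, so the validator set cannot drift from the type annotation.

```python
from typing_extensions import Literal, get_args

ScaleKeyword = Literal["log", "negative_log", "deviance"]
ICKeyword = Literal["loo", "waic"]

# get_args unpacks the Literal's members into a tuple.
print(set(get_args(ScaleKeyword)))  # {'log', 'negative_log', 'deviance'}
print(set(get_args(ICKeyword)))     # {'loo', 'waic'}
```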
75 changes: 54 additions & 21 deletions arviz/stats/stats.py
@@ -2,17 +2,24 @@
"""Statistical functions in ArviZ."""
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Mapping, cast, Callable

import numpy as np
import pandas as pd
import scipy.stats as st
import xarray as xr
from scipy.optimize import minimize
from typing_extensions import Literal

NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True

from arviz import _log
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data
from ..rcparams import rcParams
from ..rcparams import rcParams, ScaleKeyword, ICKeyword
from ..utils import Numba, _numba_var, _var_names, get_coords
from .density_utils import get_bins as _get_bins
from .density_utils import histogram as _histogram
@@ -27,9 +34,6 @@
from ..sel_utils import xarray_var_iter
from ..labels import BaseLabeller

if TYPE_CHECKING:
from typing_extensions import Literal


__all__ = [
"apply_test_function",
@@ -45,7 +49,14 @@


def compare(
dataset_dict, ic=None, method="stacking", b_samples=1000, alpha=1, seed=None, scale=None
dataset_dict: Mapping[str, InferenceData],
ic: Optional[ICKeyword] = None,
method: Literal["stacking", "BB-pseudo-BMA", "pseudo-BMA"] = "stacking",
b_samples: int = 1000,
alpha: float = 1,
seed=None,
scale: Optional[ScaleKeyword] = None,
var_name: Optional[str] = None,
):
r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation.
@@ -58,10 +69,10 @@ def compare(
----------
dataset_dict: dict[str] -> InferenceData
A dictionary of model names and InferenceData objects
ic: str
ic: str, optional
Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models. Defaults to
``rcParams["stats.information_criterion"]``.
method: str
method: str, optional
Method used to estimate the weights for each model. Available options are:
- 'stacking' : stacking of predictive distributions.
@@ -71,19 +82,19 @@
weighting, without Bootstrap stabilization (not recommended).
For more information read https://arxiv.org/abs/1704.02030
Defaults to ``rcParams["stats.ic_compare_method"]``.
b_samples: int
b_samples: int, optional, default = 1000
Number of samples taken by the Bayesian bootstrap estimation.
Only useful when method = 'BB-pseudo-BMA'.
alpha: float
alpha: float, optional
The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
on the simplex. A smaller alpha will keep the final weights further away from 0 and 1.
seed: int or np.random.RandomState instance
seed: int or np.random.RandomState instance, optional
If int or RandomState, use it for seeding Bayesian bootstrap. Only
useful when method = 'BB-pseudo-BMA'. If None (default), the global
np.random state is used.
scale: str
scale: str, optional
Output scale for IC. Available options are:
- `log` : (default) log-score (after Vehtari et al. (2017))
@@ -92,6 +103,9 @@
A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy.
var_name: str, optional
If there is more than one observed variable in the ``InferenceData`` objects,
the name of the one to use as the basis for comparison.
Returns
-------
@@ -142,10 +156,16 @@ def compare(
--------
loo : Compute the Pareto Smoothed importance sampling Leave One Out cross-validation.
waic : Compute the widely applicable information criterion.
"""
names = list(dataset_dict.keys())
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
if scale is not None:
scale = cast(ScaleKeyword, scale.lower())
else:
scale = cast(ScaleKeyword, rcParams["stats.ic_scale"])
allowable = ["log", "negative_log", "deviance"] if NO_GET_ARGS else get_args(ScaleKeyword)
if scale not in allowable:
raise ValueError(f"{scale} is not a valid value for scale: must be in {allowable}")

if scale == "log":
scale_value = 1
ascending = False
@@ -156,9 +176,15 @@
scale_value = -2
ascending = True

ic = rcParams["stats.information_criterion"] if ic is None else ic.lower()
if ic is None:
ic = cast(ICKeyword, rcParams["stats.information_criterion"])
else:
ic = cast(ICKeyword, ic.lower())
allowable = ["loo", "waic"] if NO_GET_ARGS else get_args(ICKeyword)
if ic not in allowable:
raise ValueError(f"{ic} is not a valid value for ic: must be in {allowable}")
if ic == "loo":
ic_func = loo
ic_func: Callable = loo
df_comp = pd.DataFrame(
index=names,
columns=[
@@ -172,7 +198,7 @@
"warning",
"loo_scale",
],
dtype=np.float,
dtype=np.float_,
)
scale_col = "loo_scale"
elif ic == "waic":
@@ -190,7 +216,7 @@
"warning",
"waic_scale",
],
dtype=np.float,
dtype=np.float_,
)
scale_col = "waic_scale"
else:
@@ -208,7 +234,12 @@
names = []
for name, dataset in dataset_dict.items():
names.append(name)
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale)])
try:
# Here is where the IC function is actually computed -- the rest of this
# function is argument processing and return value formatting
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)])
except Exception as e:
raise e.__class__(f"Encountered error trying to compute {ic} from model {name}.") from e
ics.index = names
ics.sort_values(by=ic, inplace=True, ascending=ascending)
ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())
@@ -1280,7 +1311,9 @@ def summary(
n_vars = np.sum([joined[var].size // n_metrics for var in joined.data_vars])

if fmt.lower() == "wide":
summary_df = pd.DataFrame(np.full((n_vars, n_metrics), np.nan), columns=metric_names)
summary_df = pd.DataFrame(
(np.full((cast(int, n_vars), n_metrics), np.nan)), columns=metric_names
)
indexs = []
for i, (var_name, sel, isel, values) in enumerate(
xarray_var_iter(joined, skip_dims={"metric"})
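A short sketch of the sharpened error path in `compare()` (the datasets are just examples): an unknown `ic` now raises a `ValueError` naming the allowable values, where the old code path surfaced a `NotImplementedError`.

```python
import arviz as az

models = {
    "centered": az.load_arviz_data("centered_eight"),
    "non_centered": az.load_arviz_data("non_centered_eight"),
}

try:
    az.compare(models, ic="Unknown")
except ValueError as err:
    # e.g. "unknown is not a valid value for ic: must be in ('loo', 'waic')"
    print(err)
```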
6 changes: 5 additions & 1 deletion arviz/stats/stats_utils.py
@@ -412,7 +412,11 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar

def get_log_likelihood(idata, var_name=None):
"""Retrieve the log likelihood dataarray of a given variable."""
if hasattr(idata, "sample_stats") and hasattr(idata.sample_stats, "log_likelihood"):
if (
not hasattr(idata, "log_likelihood")
and hasattr(idata, "sample_stats")
and hasattr(idata.sample_stats, "log_likelihood")
):
warnings.warn(
"Storing the log_likelihood in sample_stats groups has been deprecated",
DeprecationWarning,
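A minimal sketch of the tightened guard in `get_log_likelihood`, assuming `az.from_dict` is used to build the modern layout with a dedicated `log_likelihood` group: the deprecation warning now fires only when the log likelihood lives solely in `sample_stats`.

```python
import numpy as np
import arviz as az
from arviz.stats.stats_utils import get_log_likelihood

idata = az.from_dict(
    posterior={"mu": np.random.randn(4, 100)},
    log_likelihood={"obs": np.random.randn(4, 100, 8)},
)

# Modern layout: returned directly, no DeprecationWarning about sample_stats.
print(get_log_likelihood(idata).shape)  # (4, 100, 8)
```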
33 changes: 32 additions & 1 deletion arviz/tests/base_tests/test_stats.py
@@ -40,6 +40,23 @@ def non_centered_eight():
return non_centered_eight


@pytest.fixture(scope="module")
def multivariable_log_likelihood(centered_eight):
centered_eight = centered_eight.copy()
centered_eight.add_groups({"log_likelihood": centered_eight.sample_stats.log_likelihood})
centered_eight.log_likelihood = centered_eight.log_likelihood.rename_vars(
{"log_likelihood": "obs"}
)
new_arr = DataArray(
np.zeros(centered_eight.log_likelihood["obs"].values.shape),
dims=["chain", "draw", "school"],
coords=centered_eight.log_likelihood.coords,
)
centered_eight.log_likelihood["decoy"] = new_arr
delattr(centered_eight, "sample_stats")
return centered_eight


def test_hdp():
normal_sample = np.random.randn(5000000)
interval = hdi(normal_sample)
@@ -151,7 +168,7 @@ def test_compare_same(centered_eight, multidim_models, method, multidim):

def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
compare(model_dict, ic="Unknown", method="stacking")
with pytest.raises(ValueError):
compare(model_dict, ic="loo", method="Unknown")
@@ -192,6 +209,20 @@ def test_compare_different_size(centered_eight, non_centered_eight):
compare(model_dict, ic="waic", method="stacking")


@pytest.mark.parametrize("ic", ["loo", "waic"])
def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic):
compare_dict = {
"centered_eight": centered_eight,
"non_centered_eight": non_centered_eight,
"problematic": multivariable_log_likelihood,
}
with pytest.raises(TypeError, match="several log likelihood arrays"):
get_log_likelihood(compare_dict["problematic"])
with pytest.raises(TypeError, match=f"{ic}.*model problematic"):
compare(compare_dict, ic=ic)
assert compare(compare_dict, ic=ic, var_name="obs") is not None


def test_summary_ndarray():
array = np.random.randn(4, 100, 2)
summary_df = summary(array)