Skip to content

Commit

Permalink
Modify compare() docstring, error-check, pass var_name. (arviz-devs#1616
Browse files Browse the repository at this point in the history
)

* Modify compare() docstring, and var_name param and error-check.

Check to see if the log_likelihood groups have more than one data variable, and if so, require the var_name parameter. This avoids having a less understandable error message crop up later.
Introduce `var_name` parameter, and pass it through to the IC function invoked by compare.

* Fix error in argument test.

* More explicit error message.

Previously, if we got a TypeError when trying to find a log_likelihood
in from_pymc3() that TypeError would be squashed completely. Now we
will echo it to the log before handling it.

* Fix type signature of compare().

Took @OriolAbril correction.

* Annotated IC function errors from compare().

Catch errors from IC functions invoked inside compare and annotate
them with information about the source `InferenceData` object.

* Remove incorrect docstring.

* pylint

* mypy fixes.

* Make "e" acceptable variable name.

Many examples show this as a good variable name for exceptions,
particularly in "except <ExceptionClass> as e:"

* Changelog update.

* Backward-compatibility fix.

* Fix test.

Error now caught sooner.

* Fix mypy issues.

* Python 3.5 and 3.6 compatibility.

* Whitespace issue caught by Oriol.

* Test for error-trapping.

Make sure we check for multiple observed variables in compare() and that support for "var_name" works.

* Don't let sample_stats shadow log_likelihood.

Previously, we checked for `sample_stats` in `get_log_likelihood()` *before* reading `log_likelihood`.  Add a check that `log_likelihood` must be missing before we check `sample_stats`.

* Improvements suggested by OriolAbril.

* Limit recomputation in tests by scoping a fixture.
* Test for expected IC in `compare` test.
* Refine type assertion.
  • Loading branch information
rpgoldman authored and utkarsh-maheshwari committed May 27, 2021
1 parent 066e48d commit 2b4d141
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 33 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ function-naming-style=snake_case

# Good variable names which should always be accepted, separated by a comma
good-names=b,
e,
i,
j,
k,
Expand All @@ -265,7 +266,7 @@ good-names=b,
ok,
sd,
tr,
eta,
eta,
Run,
_log,
_
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* Added interactive legend to bokeh `forestplot` ([1591](https://github.com/arviz-devs/arviz/pull/1591))
* Added interactive legend to bokeh `ppcplot` ([1602](https://github.com/arviz-devs/arviz/pull/1602))
* Added `data.log_likelihood`, `stats.ic_compare_method` and `plot.density_kind` to `rcParams` ([1611](https://github.com/arviz-devs/arviz/pull/1611))
* Improve error messages in `stats.compare()`, and `var_name` parameter. ([1616](https://github.com/arviz-devs/arviz/pull/1616))

### Maintenance and fixes
* Enforced using coordinate values as default labels ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
20 changes: 13 additions & 7 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def __init__(
# this permits us to get the model from command-line argument or from with model:
try:
self.model = self.pymc3.modelcontext(model or self.model)
except TypeError:
except TypeError as e:
_log.error("Got error %s trying to find log_likelihood in translation.", e)
self.model = None

if self.model is None:
Expand Down Expand Up @@ -251,12 +252,17 @@ def _extract_log_likelihood(self, trace):
"`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
) from err
for var, log_like_fun in cached:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
try:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
except TypeError as e:
raise TypeError(
*tuple(["While computing log-likelihood for {var}: "] + list(e.args))
) from e
return log_likelihood_dict.trace_dict

@requires("trace")
Expand Down
23 changes: 21 additions & 2 deletions arviz/rcparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@
from collections.abc import MutableMapping
from pathlib import Path
from typing import Any, Dict
from typing_extensions import Literal

NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True


import numpy as np

_log = logging.getLogger(__name__)

ScaleKeyword = Literal["log", "negative_log", "deviance"]
ICKeyword = Literal["loo", "waic"]


def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
"""Validate value is in accepted_values.
Expand Down Expand Up @@ -277,9 +288,17 @@ def validate_iterable(value):
"plot.matplotlib.constrained_layout": (True, _validate_boolean),
"plot.matplotlib.show": (False, _validate_boolean),
"stats.hdi_prob": (0.94, _validate_probability),
"stats.information_criterion": ("loo", _make_validate_choice({"waic", "loo"})),
"stats.information_criterion": (
"loo",
_make_validate_choice({"loo", "waic"} if NO_GET_ARGS else set(get_args(ICKeyword))),
),
"stats.ic_pointwise": (False, _validate_boolean),
"stats.ic_scale": ("log", _make_validate_choice({"deviance", "log", "negative_log"})),
"stats.ic_scale": (
"log",
_make_validate_choice(
{"log", "negative_log", "deviance"} if NO_GET_ARGS else set(get_args(ScaleKeyword))
),
),
"stats.ic_compare_method": (
"stacking",
_make_validate_choice({"stacking", "bb-pseudo-bma", "pseudo-bma"}),
Expand Down
75 changes: 54 additions & 21 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@
"""Statistical functions in ArviZ."""
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Mapping, cast, Callable

import numpy as np
import pandas as pd
import scipy.stats as st
import xarray as xr
from scipy.optimize import minimize
from typing_extensions import Literal

NO_GET_ARGS: bool = False
try:
from typing_extensions import get_args
except ImportError:
NO_GET_ARGS = True

from arviz import _log
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data
from ..rcparams import rcParams
from ..rcparams import rcParams, ScaleKeyword, ICKeyword
from ..utils import Numba, _numba_var, _var_names, get_coords
from .density_utils import get_bins as _get_bins
from .density_utils import histogram as _histogram
Expand All @@ -27,9 +34,6 @@
from ..sel_utils import xarray_var_iter
from ..labels import BaseLabeller

if TYPE_CHECKING:
from typing_extensions import Literal


__all__ = [
"apply_test_function",
Expand All @@ -45,7 +49,14 @@


def compare(
dataset_dict, ic=None, method="stacking", b_samples=1000, alpha=1, seed=None, scale=None
dataset_dict: Mapping[str, InferenceData],
ic: Optional[ICKeyword] = None,
method: Literal["stacking", "BB-pseudo-BMA", "pseudo-MA"] = "stacking",
b_samples: int = 1000,
alpha: float = 1,
seed=None,
scale: Optional[ScaleKeyword] = None,
var_name: Optional[str] = None,
):
r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation.
Expand All @@ -58,10 +69,10 @@ def compare(
----------
dataset_dict: dict[str] -> InferenceData
A dictionary of model names and InferenceData objects
ic: str
ic: str, optional
Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models. Defaults to
``rcParams["stats.information_criterion"]``.
method: str
method: str, optional
Method used to estimate the weights for each model. Available options are:
- 'stacking' : stacking of predictive distributions.
Expand All @@ -71,19 +82,19 @@ def compare(
weighting, without Bootstrap stabilization (not recommended).
For more information read https://arxiv.org/abs/1704.02030
Defaults to ``rcParams["stats.ic_compare_method"]``.
b_samples: int
b_samples: int, optional default = 1000
Number of samples taken by the Bayesian bootstrap estimation.
Only useful when method = 'BB-pseudo-BMA'.
alpha: float
Defaults to ``rcParams["stats.ic_compare_method"]``.
alpha: float, optional
The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
on the simplex. A smaller alpha will keeps the final weights more away from 0 and 1.
seed: int or np.random.RandomState instance
seed: int or np.random.RandomState instance, optional
If int or RandomState, use it for seeding Bayesian bootstrap. Only
useful when method = 'BB-pseudo-BMA'. Default None the global
np.random state is used.
scale: str
scale: str, optional
Output scale for IC. Available options are:
- `log` : (default) log-score (after Vehtari et al. (2017))
Expand All @@ -92,6 +103,9 @@ def compare(
A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy.
var_name: str, optional
If there is more than a single observed variable in the ``InferenceData``, which
should be used as the basis for comparison.
Returns
-------
Expand Down Expand Up @@ -142,10 +156,16 @@ def compare(
--------
loo : Compute the Pareto Smoothed importance sampling Leave One Out cross-validation.
waic : Compute the widely applicable information criterion.
"""
names = list(dataset_dict.keys())
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
if scale is not None:
scale = cast(ScaleKeyword, scale.lower())
else:
scale = cast(ScaleKeyword, rcParams["stats.ic_scale"])
allowable = ["log", "negative_log", "deviance"] if NO_GET_ARGS else get_args(ScaleKeyword)
if scale not in allowable:
raise ValueError(f"{scale} is not a valid value for scale: must be in {allowable}")

if scale == "log":
scale_value = 1
ascending = False
Expand All @@ -156,9 +176,15 @@ def compare(
scale_value = -2
ascending = True

ic = rcParams["stats.information_criterion"] if ic is None else ic.lower()
if ic is None:
ic = cast(ICKeyword, rcParams["stats.information_criterion"])
else:
ic = cast(ICKeyword, ic.lower())
allowable = ["loo", "waic"] if NO_GET_ARGS else get_args(ICKeyword)
if ic not in allowable:
raise ValueError(f"{ic} is not a valid value for ic: must be in {allowable}")
if ic == "loo":
ic_func = loo
ic_func: Callable = loo
df_comp = pd.DataFrame(
index=names,
columns=[
Expand All @@ -172,7 +198,7 @@ def compare(
"warning",
"loo_scale",
],
dtype=np.float,
dtype=np.float_,
)
scale_col = "loo_scale"
elif ic == "waic":
Expand All @@ -190,7 +216,7 @@ def compare(
"warning",
"waic_scale",
],
dtype=np.float,
dtype=np.float_,
)
scale_col = "waic_scale"
else:
Expand All @@ -208,7 +234,12 @@ def compare(
names = []
for name, dataset in dataset_dict.items():
names.append(name)
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale)])
try:
# Here is where the IC function is actually computed -- the rest of this
# function is argument processing and return value formatting
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)])
except Exception as e:
raise e.__class__(f"Encountered error trying to compute {ic} from model {name}.") from e
ics.index = names
ics.sort_values(by=ic, inplace=True, ascending=ascending)
ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())
Expand Down Expand Up @@ -1280,7 +1311,9 @@ def summary(
n_vars = np.sum([joined[var].size // n_metrics for var in joined.data_vars])

if fmt.lower() == "wide":
summary_df = pd.DataFrame(np.full((n_vars, n_metrics), np.nan), columns=metric_names)
summary_df = pd.DataFrame(
(np.full((cast(int, n_vars), n_metrics), np.nan)), columns=metric_names
)
indexs = []
for i, (var_name, sel, isel, values) in enumerate(
xarray_var_iter(joined, skip_dims={"metric"})
Expand Down
6 changes: 5 additions & 1 deletion arviz/stats/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,11 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar

def get_log_likelihood(idata, var_name=None):
"""Retrieve the log likelihood dataarray of a given variable."""
if hasattr(idata, "sample_stats") and hasattr(idata.sample_stats, "log_likelihood"):
if (
not hasattr(idata, "log_likelihood")
and hasattr(idata, "sample_stats")
and hasattr(idata.sample_stats, "log_likelihood")
):
warnings.warn(
"Storing the log_likelihood in sample_stats groups has been deprecated",
DeprecationWarning,
Expand Down
33 changes: 32 additions & 1 deletion arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ def non_centered_eight():
return non_centered_eight


@pytest.fixture(scope="module")
def multivariable_log_likelihood(centered_eight):
centered_eight = centered_eight.copy()
centered_eight.add_groups({"log_likelihood": centered_eight.sample_stats.log_likelihood})
centered_eight.log_likelihood = centered_eight.log_likelihood.rename_vars(
{"log_likelihood": "obs"}
)
new_arr = DataArray(
np.zeros(centered_eight.log_likelihood["obs"].values.shape),
dims=["chain", "draw", "school"],
coords=centered_eight.log_likelihood.coords,
)
centered_eight.log_likelihood["decoy"] = new_arr
delattr(centered_eight, "sample_stats")
return centered_eight


def test_hdp():
normal_sample = np.random.randn(5000000)
interval = hdi(normal_sample)
Expand Down Expand Up @@ -151,7 +168,7 @@ def test_compare_same(centered_eight, multidim_models, method, multidim):

def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
with pytest.raises(NotImplementedError):
with pytest.raises(ValueError):
compare(model_dict, ic="Unknown", method="stacking")
with pytest.raises(ValueError):
compare(model_dict, ic="loo", method="Unknown")
Expand Down Expand Up @@ -192,6 +209,20 @@ def test_compare_different_size(centered_eight, non_centered_eight):
compare(model_dict, ic="waic", method="stacking")


@pytest.mark.parametrize("ic", ["loo", "waic"])
def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic):
compare_dict = {
"centered_eight": centered_eight,
"non_centered_eight": non_centered_eight,
"problematic": multivariable_log_likelihood,
}
with pytest.raises(TypeError, match="several log likelihood arrays"):
get_log_likelihood(compare_dict["problematic"])
with pytest.raises(TypeError, match=f"{ic}.*model problematic"):
compare(compare_dict, ic=ic)
assert compare(compare_dict, ic=ic, var_name="obs") is not None


def test_summary_ndarray():
array = np.random.randn(4, 100, 2)
summary_df = summary(array)
Expand Down

0 comments on commit 2b4d141

Please sign in to comment.