Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify compare() docstring, error-check, pass var_name. #1616

Merged
merged 19 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def __init__(
# this permits us to get the model from command-line argument or from with model:
try:
self.model = self.pymc3.modelcontext(model or self.model)
except TypeError:
except TypeError as e: # pylint: disable=invalid-name
_log.error("Got error %s trying to find log_likelihood in translation.", e)
self.model = None

if self.model is None:
Expand Down Expand Up @@ -249,12 +250,17 @@ def _extract_log_likelihood(self, trace):
"`pip install pymc3>=3.8` or `conda install -c conda-forge pymc3>=3.8`."
) from err
for var, log_like_fun in cached:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
try:
for k, chain in enumerate(trace.chains):
log_like_chain = [
self.log_likelihood_vals_point(point, var, log_like_fun)
for point in trace.points([chain])
]
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
except TypeError as e: # pylint: disable=invalid-name
raise TypeError(
*tuple(["While computing log-likelihood for {var}: "] + list(e.args))
) from e
return log_likelihood_dict.trace_dict

@requires("trace")
Expand Down
38 changes: 25 additions & 13 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
"""Statistical functions in ArviZ."""
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Mapping

import numpy as np
import pandas as pd
import scipy.stats as st
import xarray as xr
from scipy.optimize import minimize
from typing_extensions import Literal

from arviz import _log
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data
Expand All @@ -27,9 +28,6 @@
from ..sel_utils import xarray_var_iter
from ..labels import BaseLabeller

if TYPE_CHECKING:
from typing_extensions import Literal


__all__ = [
"apply_test_function",
Expand All @@ -45,7 +43,14 @@


def compare(
dataset_dict, ic=None, method="stacking", b_samples=1000, alpha=1, seed=None, scale=None
dataset_dict: Mapping[str, InferenceData],
ic: Optional[Literal["loo", "waic"]] = None,
method: Literal["stacking", "BB-pseudo-BMA", "pseudo-MA"] = "stacking",
b_samples: int = 1000,
alpha: float = 1,
seed=None,
scale: Optional[Literal["log", "negative_log", "deviance"]] = None,
var_name: Optional[str] = None,
):
r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation.

Expand All @@ -58,10 +63,10 @@ def compare(
----------
dataset_dict: dict[str] -> InferenceData
A dictionary of model names and InferenceData objects
ic: str
ic: str, optional
Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models. Defaults to
``rcParams["stats.information_criterion"]``.
method: str
method: str, optional
Method used to estimate the weights for each model. Available options are:

- 'stacking' : (default) stacking of predictive distributions.
Expand All @@ -71,18 +76,18 @@ def compare(
weighting, without Bootstrap stabilization (not recommended).

For more information read https://arxiv.org/abs/1704.02030
b_samples: int
b_samples: int, optional (default = 1000)
Number of samples taken by the Bayesian bootstrap estimation.
Only useful when method = 'BB-pseudo-BMA'.
alpha: float
alpha: float, optional
The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
on the simplex. A smaller alpha will keep the final weights further away from 0 and 1.
seed: int or np.random.RandomState instance
seed: int or np.random.RandomState instance, optional
If int or RandomState, use it for seeding Bayesian bootstrap. Only
useful when method = 'BB-pseudo-BMA'. If None (default), the global
np.random state is used.
scale: str
scale: str, optional
Output scale for IC. Available options are:

- `log` : (default) log-score (after Vehtari et al. (2017))
Expand All @@ -91,6 +96,9 @@ def compare(

A higher log-score (or a lower deviance) indicates a model with better predictive
accuracy.
var_name: str, optional
If there is more than one observed variable in the ``InferenceData``, the name of
the variable that should be used as the basis for comparison.

Returns
-------
Expand Down Expand Up @@ -141,7 +149,6 @@ def compare(
--------
loo : Compute the Pareto Smoothed importance sampling Leave One Out cross-validation.
waic : Compute the widely applicable information criterion.

"""
names = list(dataset_dict.keys())
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
Expand Down Expand Up @@ -206,7 +213,12 @@ def compare(
names = []
for name, dataset in dataset_dict.items():
names.append(name)
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale)])
try:
# Here is where the IC function is actually computed -- the rest of this
# function is argument processing and return value formatting
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale, var_name=var_name)])
except Exception as e: # pylint: disable=invalid-name
raise e.__class__(f"Encountered error trying to compute {ic} from model {name}.") from e
ics.index = names
ics.sort_values(by=ic, inplace=True, ascending=ascending)
ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())
Expand Down