Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warmup iterations and _group_warmup #1126

Merged
merged 16 commits into from
Apr 6, 2020
115 changes: 84 additions & 31 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,30 @@

SUPPORTED_GROUPS = [
"posterior",
"sample_stats",
"log_likelihood",
"posterior_predictive",
"observed_data",
"constant_data",
"predictions",
"log_likelihood",
"sample_stats",
"prior",
"sample_stats_prior",
"prior_predictive",
"predictions",
"sample_stats_prior",
"observed_data",
"constant_data",
"predictions_constant_data",
]

# Prefix that marks a group name as holding warmup iterations.
WARMUP_TAG = "_warmup_"

# Warmup-prefixed group names recognised by InferenceData.
SUPPORTED_GROUPS_WARMUP = [
"{}posterior".format(WARMUP_TAG),
"{}posterior_predictive".format(WARMUP_TAG),
"{}predictions".format(WARMUP_TAG),
"{}sample_stats".format(WARMUP_TAG),
"{}log_likelihood".format(WARMUP_TAG),
]

# Every recognised group name: the regular groups followed by the warmup groups.
SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP


class InferenceData:
"""Container for inference data storage using xarray.
Expand Down Expand Up @@ -76,36 +88,62 @@ def __init__(self, **kwargs):
"""
self._groups = []
key_list = [key for key in SUPPORTED_GROUPS if key in kwargs]
self._groups_warmup = []
save_warmup = kwargs.pop("save_warmup", False)
key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
for key in kwargs:
if key not in SUPPORTED_GROUPS:
if key not in SUPPORTED_GROUPS_ALL:
key_list.append(key)
warnings.warn(
"{} group is not defined in the InferenceData scheme".format(key), UserWarning
)
for key in key_list:
dataset = kwargs[key]
dataset_warmup = None
if dataset is None:
continue
elif isinstance(dataset, (list, tuple)):
dataset, dataset_warmup = kwargs[key]
elif not isinstance(dataset, xr.Dataset):
raise ValueError(
"Arguments to InferenceData must be xarray Datasets "
"(argument '{}' was type '{}')".format(key, type(dataset))
)
setattr(self, key, dataset)
self._groups.append(key)
if not key.startswith(WARMUP_TAG):
if dataset:
setattr(self, key, dataset)
self._groups.append(key)
elif key.startswith(WARMUP_TAG):
if dataset:
setattr(self, key, dataset)
self._groups_warmup.append(key)
if save_warmup and dataset_warmup is not None:
if dataset_warmup:
key = "{}{}".format(WARMUP_TAG, key)
setattr(self, key, dataset_warmup)
self._groups_warmup.append(key)

def __repr__(self):
"""Make string representation of object."""
return "Inference data with groups:\n\t> {options}".format(
msg = "Inference data with groups:\n\t> {options}".format(
options="\n\t> ".join(self._groups)
)
if self._groups_warmup:
msg += "\n\nWarmup iterations saved ({}*).".format(WARMUP_TAG)
return msg

def __delattr__(self, group):
"""Delete a group from the InferenceData object."""
self._groups.remove(group)
if group in self._groups:
self._groups.remove(group)
elif group in self._groups_warmup:
self._groups_warmup.remove(group)
object.__delattr__(self, group)

@property
def _groups_all(self):
return self._groups + self._groups_warmup

@staticmethod
def from_netcdf(filename):
"""Initialize object from a netcdf file.
Expand Down Expand Up @@ -155,11 +193,12 @@ def to_netcdf(self, filename, compress=True, groups=None):
Location of netcdf file
"""
mode = "w" # overwrite first, then append
if self._groups: # check's whether a group is present or not.
if self._groups_all: # checks whether a group is present or not.
if groups is None:
groups = self._groups
groups = self._groups_all
else:
groups = [group for group in self._groups if group in groups]
groups = [group for group in self._groups_all if group in groups]

for group in groups:
data = getattr(self, group)
kwargs = {}
Expand All @@ -177,7 +216,7 @@ def __add__(self, other):
"""Concatenate two InferenceData objects."""
return concat(self, other, copy=True, inplace=False)

def sel(self, inplace=False, chain_prior=False, **kwargs):
def sel(self, inplace=False, chain_prior=False, warmup=False, **kwargs):
"""Perform an xarray selection on all groups.
Loops over all groups to perform Dataset.sel(key=item)
Expand All @@ -194,9 +233,11 @@ def sel(self, inplace=False, chain_prior=False, **kwargs):
otherwise, return the modified copy.
chain_prior: bool, optional
If ``False``, do not select prior related groups using ``chain`` dim.
Otherwise, use selection on ``chain`` if present
Otherwise, use selection on ``chain`` if present.
warmup: bool, optional
If ``False``, do not select warmup groups.
**kwargs : mapping
It must be accepted by Dataset.sel()
It must be accepted by Dataset.sel().
Returns
-------
Expand Down Expand Up @@ -229,7 +270,7 @@ def sel(self, inplace=False, chain_prior=False, **kwargs):
"""
out = self if inplace else deepcopy(self)
for group in self._groups:
for group in self._groups_all if warmup else self._groups:
dataset = getattr(self, group)
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
if not chain_prior and "prior" in group:
Expand Down Expand Up @@ -359,12 +400,12 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):

if dim is None:
arg0 = args[0]
arg0_groups = ccopy(arg0._groups)
arg0_groups = ccopy(arg0._groups_all)
args_groups = dict()
# check if groups are independent
# Concat over unique groups
for arg in args[1:]:
for group in arg._groups:
for group in arg._groups_all:
if group in args_groups or group in arg0_groups:
msg = (
"Concatenating overlapping groups is not supported unless `dim` is defined."
Expand All @@ -374,41 +415,53 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
group_data = getattr(arg, group)
args_groups[group] = deepcopy(group_data) if copy else group_data
# add arg0 to args_groups if inplace is False
# otherwise it will merge args_groups to arg0
# inference data object
if not inplace:
for group in arg0_groups:
group_data = getattr(arg0, group)
args_groups[group] = deepcopy(group_data) if copy else group_data

other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS]
other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]

for group in SUPPORTED_GROUPS + other_groups:
for group in SUPPORTED_GROUPS_ALL + other_groups:
if group not in args_groups:
continue
if inplace:
arg0._groups.append(group)
if group.startswith(WARMUP_TAG):
arg0._groups_warmup.append(group)
else:
arg0._groups.append(group)
setattr(arg0, group, args_groups[group])
else:
inference_data_dict[group] = args_groups[group]
if inplace:
other_groups = [
group for group in arg0_groups if group not in SUPPORTED_GROUPS
group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
] + other_groups
sorted_groups = [
group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
]
setattr(arg0, "_groups", sorted_groups)
sorted_groups_warmup = [
group
for group in SUPPORTED_GROUPS_WARMUP + other_groups
if group in arg0._groups_warmup
]
setattr(arg0, "_groups_warmup", sorted_groups_warmup)
else:
arg0 = args[0]
arg0_groups = arg0._groups
arg0_groups = arg0._groups_all
for arg in args[1:]:
for group0 in arg0_groups:
if group0 not in arg._groups:
if group0 not in arg._groups_all:
if group0 == "observed_data":
continue
msg = "Mismatch between the groups."
raise TypeError(msg)
for group in arg._groups:
if group != "observed_data":
for group in arg._groups_all:
# handle data groups separately
if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
# assert that groups are equal
if group not in arg0_groups:
msg = "Mismatch between the groups."
Expand Down Expand Up @@ -500,7 +553,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
else:
inference_data_dict[group] = concatenated_group
else:
# observed_data
# observed_data, constant_data, predictions_constant_data
if group not in arg0_groups:
setattr(arg0, group, deepcopy(group_data) if copy else group_data)
arg0._groups.append(group)
Expand All @@ -518,7 +571,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
for var in group_vars:
if var not in group0_vars:
var_data = getattr(group_data, var)
arg0.observed_data[var] = var_data
getattr(arg0, group)[var] = var_data
else:
var_data = getattr(group_data, var)
var0_data = getattr(group0_data, var)
Expand Down
18 changes: 14 additions & 4 deletions arviz/data/inference_data.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,27 @@ import xarray as xr

class InferenceData:
posterior: Optional[xr.Dataset]
observations: Optional[xr.Dataset]
constant_data: Optional[xr.Dataset]
prior: Optional[xr.Dataset]
prior_predictive: Optional[xr.Dataset]
posterior_predictive: Optional[xr.Dataset]
predictions: Optional[xr.Dataset]
log_likelihood: Optional[xr.Dataset]
sample_stats: Optional[xr.Dataset]
observed_data: Optional[xr.Dataset]
constant_data: Optional[xr.Dataset]
predictions_constant_data: Optional[xr.Dataset]
prior: Optional[xr.Dataset]
prior_predictive: Optional[xr.Dataset]
sample_stats_prior: Optional[xr.Dataset]
_warmup_posterior: Optional[xr.Dataset]
_warmup_posterior_predictive: Optional[xr.Dataset]
_warmup_predictions: Optional[xr.Dataset]
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
_warmup_log_likelihood: Optional[xr.Dataset]
_warmup_sample_stats: Optional[xr.Dataset]
def __init__(self, **kwargs): ...
def __repr__(self) -> str: ...
def __delattr__(self, group: str) -> None: ...
def __add__(self, other: "InferenceData"): ...
@property
def _groups_all(self) -> List[str]: ...
@staticmethod
def from_netcdf(filename: str) -> "InferenceData": ...
def to_netcdf(
Expand Down
Loading