Skip to content

Commit

Permalink
Add warmup iterations and _group_warmup (#1126)
Browse files Browse the repository at this point in the history
* pystan save warmup

* fix inferencedata handling

* _group_warmup

* default save warmup false

* Change order

* fix test

* pylint fix

* rcparam and tests

* fix logic for warmup

* simplify test

* Line len

* Remove ws

* simplify code

* Default behaviour for get_sample_stats

* Missing self

* fix empty datasets and minor fixes

Co-authored-by: ahartikainen <[email protected]>
  • Loading branch information
ahartikainen and ahartikainen authored Apr 6, 2020
1 parent 1b2ecdd commit 6600929
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 112 deletions.
115 changes: 84 additions & 31 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,30 @@

SUPPORTED_GROUPS = [
"posterior",
"sample_stats",
"log_likelihood",
"posterior_predictive",
"observed_data",
"constant_data",
"predictions",
"log_likelihood",
"sample_stats",
"prior",
"sample_stats_prior",
"prior_predictive",
"predictions",
"sample_stats_prior",
"observed_data",
"constant_data",
"predictions_constant_data",
]

# Prefix attached to group names that hold warmup (pre-sampling) draws.
WARMUP_TAG = "_warmup_"

# Groups that may carry a warmup counterpart, derived from the tag so the
# prefix is defined in exactly one place.
SUPPORTED_GROUPS_WARMUP = [
    WARMUP_TAG + group_name
    for group_name in (
        "posterior",
        "posterior_predictive",
        "predictions",
        "sample_stats",
        "log_likelihood",
    )
]

SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP


class InferenceData:
"""Container for inference data storage using xarray.
Expand Down Expand Up @@ -76,36 +88,62 @@ def __init__(self, **kwargs):
"""
self._groups = []
key_list = [key for key in SUPPORTED_GROUPS if key in kwargs]
self._groups_warmup = []
save_warmup = kwargs.pop("save_warmup", False)
key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
for key in kwargs:
if key not in SUPPORTED_GROUPS:
if key not in SUPPORTED_GROUPS_ALL:
key_list.append(key)
warnings.warn(
"{} group is not defined in the InferenceData scheme".format(key), UserWarning
)
for key in key_list:
dataset = kwargs[key]
dataset_warmup = None
if dataset is None:
continue
elif isinstance(dataset, (list, tuple)):
dataset, dataset_warmup = kwargs[key]
elif not isinstance(dataset, xr.Dataset):
raise ValueError(
"Arguments to InferenceData must be xarray Datasets "
"(argument '{}' was type '{}')".format(key, type(dataset))
)
setattr(self, key, dataset)
self._groups.append(key)
if not key.startswith(WARMUP_TAG):
if dataset:
setattr(self, key, dataset)
self._groups.append(key)
elif key.startswith(WARMUP_TAG):
if dataset:
setattr(self, key, dataset)
self._groups_warmup.append(key)
if save_warmup and dataset_warmup is not None:
if dataset_warmup:
key = "{}{}".format(WARMUP_TAG, key)
setattr(self, key, dataset_warmup)
self._groups_warmup.append(key)

def __repr__(self):
"""Make string representation of object."""
return "Inference data with groups:\n\t> {options}".format(
msg = "Inference data with groups:\n\t> {options}".format(
options="\n\t> ".join(self._groups)
)
if self._groups_warmup:
msg += "\n\nWarmup iterations saved ({}*).".format(WARMUP_TAG)
return msg

def __delattr__(self, group):
"""Delete a group from the InferenceData object."""
self._groups.remove(group)
if group in self._groups:
self._groups.remove(group)
elif group in self._groups_warmup:
self._groups_warmup.remove(group)
object.__delattr__(self, group)

@property
def _groups_all(self):
    """Return the names of all groups, regular groups first, then warmup groups."""
    return self._groups + self._groups_warmup

@staticmethod
def from_netcdf(filename):
"""Initialize object from a netcdf file.
Expand Down Expand Up @@ -155,11 +193,12 @@ def to_netcdf(self, filename, compress=True, groups=None):
Location of netcdf file
"""
mode = "w" # overwrite first, then append
if self._groups: # check's whether a group is present or not.
if self._groups_all: # check's whether a group is present or not.
if groups is None:
groups = self._groups
groups = self._groups_all
else:
groups = [group for group in self._groups if group in groups]
groups = [group for group in self._groups_all if group in groups]

for group in groups:
data = getattr(self, group)
kwargs = {}
Expand All @@ -177,7 +216,7 @@ def __add__(self, other):
"""Concatenate two InferenceData objects."""
return concat(self, other, copy=True, inplace=False)

def sel(self, inplace=False, chain_prior=False, **kwargs):
def sel(self, inplace=False, chain_prior=False, warmup=False, **kwargs):
"""Perform an xarray selection on all groups.
Loops over all groups to perform Dataset.sel(key=item)
Expand All @@ -194,9 +233,11 @@ def sel(self, inplace=False, chain_prior=False, **kwargs):
otherwise, return the modified copy.
chain_prior: bool, optional
If ``False``, do not select prior related groups using ``chain`` dim.
Otherwise, use selection on ``chain`` if present
Otherwise, use selection on ``chain`` if present.
warmup: bool, optional
If ``False``, do not select warmup groups.
**kwargs : mapping
It must be accepted by Dataset.sel()
It must be accepted by Dataset.sel().
Returns
-------
Expand Down Expand Up @@ -229,7 +270,7 @@ def sel(self, inplace=False, chain_prior=False, **kwargs):
"""
out = self if inplace else deepcopy(self)
for group in self._groups:
for group in self._groups_all if warmup else self._groups:
dataset = getattr(self, group)
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
if not chain_prior and "prior" in group:
Expand Down Expand Up @@ -359,12 +400,12 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):

if dim is None:
arg0 = args[0]
arg0_groups = ccopy(arg0._groups)
arg0_groups = ccopy(arg0._groups_all)
args_groups = dict()
# check if groups are independent
# Concat over unique groups
for arg in args[1:]:
for group in arg._groups:
for group in arg._groups_all:
if group in args_groups or group in arg0_groups:
msg = (
"Concatenating overlapping groups is not supported unless `dim` is defined."
Expand All @@ -374,41 +415,53 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
group_data = getattr(arg, group)
args_groups[group] = deepcopy(group_data) if copy else group_data
# add arg0 to args_groups if inplace is False
# otherwise it will merge args_groups to arg0
# inference data object
if not inplace:
for group in arg0_groups:
group_data = getattr(arg0, group)
args_groups[group] = deepcopy(group_data) if copy else group_data

other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS]
other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]

for group in SUPPORTED_GROUPS + other_groups:
for group in SUPPORTED_GROUPS_ALL + other_groups:
if group not in args_groups:
continue
if inplace:
arg0._groups.append(group)
if group.startswith(WARMUP_TAG):
arg0._groups_warmup.append(group)
else:
arg0._groups.append(group)
setattr(arg0, group, args_groups[group])
else:
inference_data_dict[group] = args_groups[group]
if inplace:
other_groups = [
group for group in arg0_groups if group not in SUPPORTED_GROUPS
group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
] + other_groups
sorted_groups = [
group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
]
setattr(arg0, "_groups", sorted_groups)
sorted_groups_warmup = [
group
for group in SUPPORTED_GROUPS_WARMUP + other_groups
if group in arg0._groups_warmup
]
setattr(arg0, "_groups_warmup", sorted_groups_warmup)
else:
arg0 = args[0]
arg0_groups = arg0._groups
arg0_groups = arg0._groups_all
for arg in args[1:]:
for group0 in arg0_groups:
if group0 not in arg._groups:
if group0 not in arg._groups_all:
if group0 == "observed_data":
continue
msg = "Mismatch between the groups."
raise TypeError(msg)
for group in arg._groups:
if group != "observed_data":
for group in arg._groups_all:
# handle data groups separately
if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
# assert that groups are equal
if group not in arg0_groups:
msg = "Mismatch between the groups."
Expand Down Expand Up @@ -500,7 +553,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
else:
inference_data_dict[group] = concatenated_group
else:
# observed_data
# observed_data, constant_data, predictions_constant_data
if group not in arg0_groups:
setattr(arg0, group, deepcopy(group_data) if copy else group_data)
arg0._groups.append(group)
Expand All @@ -518,7 +571,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
for var in group_vars:
if var not in group0_vars:
var_data = getattr(group_data, var)
arg0.observed_data[var] = var_data
getattr(arg0, group)[var] = var_data
else:
var_data = getattr(group_data, var)
var0_data = getattr(group0_data, var)
Expand Down
18 changes: 14 additions & 4 deletions arviz/data/inference_data.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,27 @@ import xarray as xr

class InferenceData:
posterior: Optional[xr.Dataset]
observations: Optional[xr.Dataset]
constant_data: Optional[xr.Dataset]
prior: Optional[xr.Dataset]
prior_predictive: Optional[xr.Dataset]
posterior_predictive: Optional[xr.Dataset]
predictions: Optional[xr.Dataset]
log_likelihood: Optional[xr.Dataset]
sample_stats: Optional[xr.Dataset]
observed_data: Optional[xr.Dataset]
constant_data: Optional[xr.Dataset]
predictions_constant_data: Optional[xr.Dataset]
prior: Optional[xr.Dataset]
prior_predictive: Optional[xr.Dataset]
sample_stats_prior: Optional[xr.Dataset]
_warmup_posterior: Optional[xr.Dataset]
_warmup_posterior_predictive: Optional[xr.Dataset]
_warmup_predictions: Optional[xr.Dataset]
_warmup_log_likelihood: Optional[xr.Dataset]
_warmup_sample_stats: Optional[xr.Dataset]
def __init__(self, **kwargs): ...
def __repr__(self) -> str: ...
def __delattr__(self, group: str) -> None: ...
def __add__(self, other: "InferenceData"): ...
@property
def _groups_all(self) -> List[str]: ...
@staticmethod
def from_netcdf(filename: str) -> "InferenceData": ...
def to_netcdf(
Expand Down
Loading

0 comments on commit 6600929

Please sign in to comment.