Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warmup iterations and _group_warmup #1126

Merged
merged 16 commits into from
Apr 6, 2020
97 changes: 76 additions & 21 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@
"predictions_constant_data",
]

# Group names reserved for warmup iterations. Every entry carries the
# "_warmup_" prefix; InferenceData.__init__ accepts these keys without
# emitting the "not defined in the InferenceData scheme" warning.
SUPPORTED_GROUPS_WARMUP = [
    "_warmup_{}".format(base_group)
    for base_group in (
        "posterior",
        "sample_stats",
        "log_likelihood",
        "posterior_predictive",
        "prior",
        "sample_stats_prior",
        "prior_predictive",
        "predictions",
    )
]


class InferenceData:
"""Container for inference data storage using xarray.
Expand Down Expand Up @@ -76,34 +87,53 @@ def __init__(self, **kwargs):

"""
self._groups = []
key_list = [key for key in SUPPORTED_GROUPS if key in kwargs]
self._groups_warmup = []
save_warmup = kwargs.pop("save_warmup", False)
key_list = [key for key in SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP if key in kwargs]
for key in kwargs:
if key not in SUPPORTED_GROUPS:
if key not in SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP:
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
key_list.append(key)
warnings.warn(
"{} group is not defined in the InferenceData scheme".format(key), UserWarning
)
for key in key_list:
dataset = kwargs[key]
dataset_warmup = None
if dataset is None:
continue
elif isinstance(dataset, (list, tuple)):
dataset, dataset_warmup = kwargs[key]
elif not isinstance(dataset, xr.Dataset):
raise ValueError(
"Arguments to InferenceData must be xarray Datasets "
"(argument '{}' was type '{}')".format(key, type(dataset))
)
setattr(self, key, dataset)
self._groups.append(key)
if not key.startswith("_warmup_"):
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
setattr(self, key, dataset)
self._groups.append(key)
elif key.startswith("_warmup_"):
setattr(self, key, dataset)
self._groups_warmup.append(key)
if save_warmup and dataset_warmup is not None:
key = "_warmup_{}".format(key)
setattr(self, key, dataset_warmup)
self._groups_warmup.append(key)

def __repr__(self):
"""Make string representation of object."""
return "Inference data with groups:\n\t> {options}".format(
msg = "Inference data with groups:\n\t> {options}".format(
options="\n\t> ".join(self._groups)
)
if self._groups_warmup:
msg += "warmup iterations saved."
return msg

def __delattr__(self, group):
"""Delete a group from the InferenceData object."""
self._groups.remove(group)
if group in self._groups:
self._groups.remove(group)
elif group in self._groups_warmup:
self._groups_warmup.remove(group)
object.__delattr__(self, group)

@staticmethod
Expand Down Expand Up @@ -155,11 +185,12 @@ def to_netcdf(self, filename, compress=True, groups=None):
Location of netcdf file
"""
mode = "w" # overwrite first, then append
if self._groups: # check's whether a group is present or not.
if self._groups or self._groups_warmup: # checks whether a group is present or not.
if groups is None:
groups = self._groups
groups = self._groups + self._groups_warmup
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
else:
groups = [group for group in self._groups if group in groups]
groups = [group for group in self._groups + self._groups_warmup if group in groups]

for group in groups:
data = getattr(self, group)
kwargs = {}
Expand Down Expand Up @@ -236,6 +267,13 @@ def sel(self, inplace=False, chain_prior=False, **kwargs):
valid_keys -= {"chain"}
dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
setattr(out, group, dataset)
for group in self._groups_warmup:
dataset = getattr(self, group)
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
if not chain_prior and "prior" in group:
valid_keys -= {"chain"}
dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
setattr(out, group, dataset)
if inplace:
return None
else:
Expand Down Expand Up @@ -359,12 +397,12 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):

if dim is None:
arg0 = args[0]
arg0_groups = ccopy(arg0._groups)
arg0_groups = ccopy(arg0._groups + arg0._groups_warmup)
args_groups = dict()
# check if groups are independent
# Concat over unique groups
for arg in args[1:]:
for group in arg._groups:
for group in arg._groups + arg._groups_warmup:
if group in args_groups or group in arg0_groups:
msg = (
"Concatenating overlapping groups is not supported unless `dim` is defined."
Expand All @@ -374,41 +412,58 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
group_data = getattr(arg, group)
args_groups[group] = deepcopy(group_data) if copy else group_data
# add arg0 to args_groups if inplace is False
# otherwise it will merge args_groups to arg0
# inference data object
if not inplace:
for group in arg0_groups:
group_data = getattr(arg0, group)
args_groups[group] = deepcopy(group_data) if copy else group_data

other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS]
other_groups = [
group
for group in args_groups
if group not in SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
]

for group in SUPPORTED_GROUPS + other_groups:
for group in SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP + other_groups:
if group not in args_groups:
continue
if inplace:
arg0._groups.append(group)
if group.startswith("_warmup_"):
arg0._groups_warmup.append(group)
else:
arg0._groups.append(group)
setattr(arg0, group, args_groups[group])
else:
inference_data_dict[group] = args_groups[group]
if inplace:
other_groups = [
group for group in arg0_groups if group not in SUPPORTED_GROUPS
group
for group in arg0_groups
if group not in SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
] + other_groups
sorted_groups = [
group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
]
setattr(arg0, "_groups", sorted_groups)
sorted_groups_warmup = [
group
for group in SUPPORTED_GROUPS_WARMUP + other_groups
if group in arg0._groups_warmup
]
setattr(arg0, "_groups_warmup", sorted_groups_warmup)
else:
arg0 = args[0]
arg0_groups = arg0._groups
arg0_groups = arg0._groups + arg0._groups_warmup
for arg in args[1:]:
for group0 in arg0_groups:
if group0 not in arg._groups:
if group0 not in arg._groups + arg._groups_warmup:
if group0 == "observed_data":
continue
msg = "Mismatch between the groups."
raise TypeError(msg)
for group in arg._groups:
if group != "observed_data":
for group in arg._groups + arg._groups_warmup:
if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
# assert that groups are equal
if group not in arg0_groups:
msg = "Mismatch between the groups."
Expand Down Expand Up @@ -500,7 +555,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
else:
inference_data_dict[group] = concatenated_group
else:
# observed_data
# observed_data, constant_data, predictions_constant_data
if group not in arg0_groups:
setattr(arg0, group, deepcopy(group_data) if copy else group_data)
arg0._groups.append(group)
Expand All @@ -518,7 +573,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
for var in group_vars:
if var not in group0_vars:
var_data = getattr(group_data, var)
arg0.observed_data[var] = var_data
getattr(arg0, group)[var] = var_data
else:
var_data = getattr(group_data, var)
var0_data = getattr(group0_data, var)
Expand Down
3 changes: 3 additions & 0 deletions arviz/data/inference_data.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class InferenceData:
posterior_predictive: Optional[xr.Dataset]
predictions: Optional[xr.Dataset]
predictions_constant_data: Optional[xr.Dataset]
_warmup_posterior: Optional[xr.Dataset]
_warmup_posterior_predictive: Optional[xr.Dataset]
_warmup_predictions: Optional[xr.Dataset]
ahartikainen marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, **kwargs): ...
def __repr__(self) -> str: ...
def __delattr__(self, group: str) -> None: ...
Expand Down
Loading