Skip to content

Commit

Permalink
Add attr and name to idata (#1357)
Browse files Browse the repository at this point in the history
* added attr

* added attr to az.concat

* modified attrs

* added documentation for group kwargs

* test error

* final linting change

* pydocstyle

* minimal change

* Fix typo

* Fix long line

* Update inference_data.py

* Update io_dict.py

* mypy fixes

* Update inference_data.py

Co-authored-by: Ari Hartikainen <[email protected]>
Co-authored-by: Ari Hartikainen <[email protected]>
  • Loading branch information
3 people authored Jan 16, 2022
1 parent 08929a4 commit 4544dd0
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 33 deletions.
47 changes: 37 additions & 10 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,16 @@ class InferenceData(Mapping[str, xr.Dataset]):
"""

def __init__(
self, **kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]]
self,
attrs: Union[None, Mapping[Any, Any]] = None,
**kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
) -> None:
"""Initialize InferenceData object from keyword xarray datasets.
Parameters
----------
attrs : dict
sets global attribute for InferenceData object.
kwargs :
Keyword arguments of xarray datasets
Expand Down Expand Up @@ -133,6 +137,7 @@ def __init__(
"""
self._groups: List[str] = []
self._groups_warmup: List[str] = []
self._attrs: Union[None, dict] = dict(attrs) if attrs is not None else None
save_warmup = kwargs.pop("save_warmup", False)
key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
for key in kwargs:
Expand Down Expand Up @@ -167,6 +172,17 @@ def __init__(
setattr(self, key, dataset_warmup)
self._groups_warmup.append(key)

@property
def attrs(self) -> dict:
"""Attributes of InferenceData object."""
if self._attrs is None:
self._attrs = {}
return self._attrs

@attrs.setter
def attrs(self, value) -> None:
self._attrs = dict(value)

def __repr__(self) -> str:
"""Make string representation of InferenceData object."""
msg = "Inference data with groups:\n\t> {options}".format(
Expand Down Expand Up @@ -437,7 +453,6 @@ def to_dict(self, groups=None, filter_groups=None):
When `data=False` return just the schema.
"""
ret = defaultdict(dict)
attrs = None
if self._groups_all: # check's whether a group is present or not.
if groups is None:
groups = self._group_names(groups, filter_groups)
Expand Down Expand Up @@ -467,15 +482,9 @@ def to_dict(self, groups=None, filter_groups=None):
if len(dims) > 0:
ret[dims_key][var_name] = dims
ret[group] = data
if attrs is None:
attrs = dataset.attrs
elif attrs != dataset.attrs:
warnings.warn(
"The attributes are not same for all groups."
" Considering only the first group `attrs`"
)
ret[group + "_attrs"] = dataset.attrs

ret["attrs"] = attrs
ret["attrs"] = self.attrs
return ret

def to_json(self, filename, groups=None, filter_groups=None, **kwargs):
Expand Down Expand Up @@ -1874,6 +1883,21 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
return args[0]

current_time = str(datetime.now())
combined_attr = defaultdict(list)
for idata in args:
for key, val in idata.attrs.items():
combined_attr[key].append(val)

for key, val in combined_attr.items():
all_same = True
for indx in range(len(val) - 1):
if val[indx] != val[indx + 1]:
all_same = False
break
if all_same:
combined_attr[key] = val[0]
if inplace:
setattr(args[0], "_attrs", dict(combined_attr))

if not inplace:
# Keep order for python 3.5
Expand Down Expand Up @@ -2114,4 +2138,7 @@ def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
else:
inference_data_dict[group] = group0_data

if not inplace:
inference_data_dict["attrs"] = combined_attr

return None if inplace else InferenceData(**inference_data_dict)
64 changes: 43 additions & 21 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
pred_dims=None,
pred_coords=None,
attrs=None,
**kwargs,
):
self.posterior = posterior
self.posterior_predictive = posterior_predictive
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
self.attrs = {} if attrs is None else attrs
self.attrs.pop("created_at", None)
self.attrs.pop("arviz_version", None)
self._kwargs = kwargs

def _init_dict(self, attr_name):
dict_or_none = getattr(self, attr_name, {})
Expand All @@ -90,22 +92,23 @@ def posterior_to_xarray(self):
" For stats functions log likelihood data needs to be in log_likelihood group.",
UserWarning,
)

posterior_attrs = self._kwargs.get("posterior_attrs")
posterior_warmup_attrs = self._kwargs.get("posterior_warmup_attrs")
return (
dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=posterior_attrs,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=posterior_warmup_attrs,
index_origin=self.index_origin,
),
)
Expand All @@ -127,22 +130,23 @@ def sample_stats_to_xarray(self):
"favour of storing them in the log_likelihood group.",
PendingDeprecationWarning,
)

sample_stats_attrs = self._kwargs.get("sample_stats_attrs")
sample_stats_warmup_attrs = self._kwargs.get("sample_stats_warmup_attrs")
return (
dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=sample_stats_attrs,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=sample_stats_warmup_attrs,
index_origin=self.index_origin,
),
)
Expand All @@ -156,14 +160,15 @@ def log_likelihood_to_xarray(self):
raise TypeError("DictConverter.log_likelihood is not a dictionary")
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_log_likelihood is not a dictionary")

log_likelihood_attrs = self._kwargs.get("log_likelihood_attrs")
log_likelihood_warmup_attrs = self._kwargs.get("log_likelihood_warmup_attrs")
return (
dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=log_likelihood_attrs,
index_origin=self.index_origin,
skip_event_dims=True,
),
Expand All @@ -172,7 +177,7 @@ def log_likelihood_to_xarray(self):
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=log_likelihood_warmup_attrs,
index_origin=self.index_origin,
skip_event_dims=True,
),
Expand All @@ -187,22 +192,23 @@ def posterior_predictive_to_xarray(self):
raise TypeError("DictConverter.posterior_predictive is not a dictionary")
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_posterior_predictive is not a dictionary")

posterior_predictive_attrs = self._kwargs.get("posterior_predictive_attrs")
posterior_predictive_warmup_attrs = self._kwargs.get("posterior_predictive_warmup_attrs")
return (
dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=posterior_predictive_attrs,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=posterior_predictive_warmup_attrs,
index_origin=self.index_origin,
),
)
Expand All @@ -216,22 +222,23 @@ def predictions_to_xarray(self):
raise TypeError("DictConverter.predictions is not a dictionary")
if not isinstance(data_warmup, dict):
raise TypeError("DictConverter.warmup_predictions is not a dictionary")

predictions_attrs = self._kwargs.get("predictions_attrs")
predictions_warmup_attrs = self._kwargs.get("predictions_warmup_attrs")
return (
dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.pred_dims,
attrs=self.attrs,
attrs=predictions_attrs,
index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=None,
coords=self.coords,
dims=self.pred_dims,
attrs=self.attrs,
attrs=predictions_warmup_attrs,
index_origin=self.index_origin,
),
)
Expand All @@ -242,13 +249,13 @@ def prior_to_xarray(self):
data = self.prior
if not isinstance(data, dict):
raise TypeError("DictConverter.prior is not a dictionary")

prior_attrs = self._kwargs.get("prior_attrs")
return dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=prior_attrs,
index_origin=self.index_origin,
)

Expand All @@ -258,13 +265,13 @@ def sample_stats_prior_to_xarray(self):
data = self.sample_stats_prior
if not isinstance(data, dict):
raise TypeError("DictConverter.sample_stats_prior is not a dictionary")

sample_stats_prior_attrs = self._kwargs.get("sample_stats_prior_attrs")
return dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=sample_stats_prior_attrs,
index_origin=self.index_origin,
)

Expand All @@ -274,13 +281,13 @@ def prior_predictive_to_xarray(self):
data = self.prior_predictive
if not isinstance(data, dict):
raise TypeError("DictConverter.prior_predictive is not a dictionary")

prior_predictive_attrs = self._kwargs.get("prior_predictive_attrs")
return dict_to_dataset(
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
attrs=prior_predictive_attrs,
index_origin=self.index_origin,
)

Expand Down Expand Up @@ -337,6 +344,7 @@ def to_inference_data(self):
"constant_data": self.constant_data_to_xarray(),
"predictions_constant_data": self.predictions_constant_data_to_xarray(),
"save_warmup": self.save_warmup,
"attrs": self.attrs,
}
)

Expand Down Expand Up @@ -367,6 +375,7 @@ def from_dict(
pred_dims=None,
pred_coords=None,
attrs=None,
**kwargs,
):
"""Convert Dictionary data into an InferenceData object.
Expand Down Expand Up @@ -406,6 +415,18 @@ def from_dict(
A mapping from variables to a list of coordinate values for predictions.
attrs : dict
A dictionary containing attributes for different groups.
kwargs : dict
A dictionary containing group attrs.
Accepted kwargs are:
- posterior_attrs, posterior_warmup_attrs : attrs for posterior group
- sample_stats_attrs, sample_stats_warmup_attrs : attrs for sample_stats group
- log_likelihood_attrs, log_likelihood_warmup_attrs : attrs for log_likelihood group
- posterior_predictive_attrs, posterior_predictive_warmup_attrs : attrs for
posterior_predictive group
- predictions_attrs, predictions_warmup_attrs : attrs for predictions group
- prior_attrs : attrs for prior group
- sample_stats_prior_attrs : attrs for sample_stats_prior group
- prior_predictive_attrs : attrs for prior_predictive group
Returns
-------
Expand Down Expand Up @@ -435,4 +456,5 @@ def from_dict(
pred_dims=pred_dims,
pred_coords=pred_coords,
attrs=attrs,
**kwargs,
).to_inference_data()
4 changes: 2 additions & 2 deletions arvizrc.template
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ data.index_origin : 0 # index origin, must be either 0 or 1
data.load : lazy # Sets the default data loading mode.
# "lazy" stands for xarray lazy loading,
# "eager" loads all datasets into memory
data.log_likelihood : true # save pointwise log likelihood values, one of "true", "false"
data.log_likelihood : true # save pointwise log likelihood values, one of "true", "false"
data.metagroups : {
posterior_groups: posterior, posterior_predictive, sample_stats, log_likelihood
prior_groups: prior, prior_predictive, sample_stats_prior
posterior_groups_warmup: _warmup_posterior, _warmup_posterior_predictive, _warmup_sample_stats
latent_vars: posterior, prior
observed_vars: posterior_predictive, observed_data, prior_predictive
}
data.save_warmup : false # save warmup iterations, one of "true", "false"
data.save_warmup : false # save warmup iterations, one of "true", "false"

### PLOT ###
# rcParams related with plotting functions
Expand Down

0 comments on commit 4544dd0

Please sign in to comment.