Skip to content

Commit

Permalink
From dict function (#524)
Browse files Browse the repository at this point in the history
* add io_dict

* add dict tests

* add tests for edge cases
  • Loading branch information
ahartikainen authored Jan 13, 2019
1 parent da7f066 commit e88e1bf
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 0 deletions.
2 changes: 2 additions & 0 deletions arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .base import numpy_to_data_array, dict_to_dataset
from .converters import convert_to_dataset, convert_to_inference_data
from .io_cmdstan import from_cmdstan
from .io_dict import from_dict
from .io_pymc3 import from_pymc3
from .io_pystan import from_pystan
from .io_emcee import from_emcee
Expand All @@ -24,6 +25,7 @@
"from_pystan",
"from_emcee",
"from_cmdstan",
"from_dict",
"from_pyro",
"from_tfp",
"from_netcdf",
Expand Down
182 changes: 182 additions & 0 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""Dictionary specific conversion code."""
import logging

import numpy as np
import xarray as xr

from .inference_data import InferenceData
from .base import requires, dict_to_dataset, generate_dims_coords, make_attrs

_log = logging.getLogger("arviz")


class DictConverter:
    """Encapsulate Dictionary specific logic.

    Every group is stored as given on the instance. Each ``<group>_to_xarray``
    method converts the attribute of the same name, and ``to_inference_data``
    collects the results of all of them into one ``InferenceData``.
    """

    def __init__(
        self,
        *,
        posterior=None,
        posterior_predictive=None,
        sample_stats=None,
        prior=None,
        prior_predictive=None,
        sample_stats_prior=None,
        observed_data=None,
        coords=None,
        dims=None,
    ):
        # Inputs are stored untouched; type validation happens lazily in the
        # individual ``*_to_xarray`` methods.
        self.posterior = posterior
        self.posterior_predictive = posterior_predictive
        self.sample_stats = sample_stats
        self.prior = prior
        self.prior_predictive = prior_predictive
        self.sample_stats_prior = sample_stats_prior
        self.observed_data = observed_data
        self.coords = coords
        self.dims = dims

    def _group_to_xarray(self, group):
        """Validate ``self.<group>`` and hand it to ``dict_to_dataset``.

        Shared implementation of the sampled groups, which all convert the
        same way and differ only in the attribute they read.

        Parameters
        ----------
        group : str
            Name of the instance attribute holding the group's data.

        Returns
        -------
        The result of ``dict_to_dataset`` called with ``library=None`` and the
        converter's ``coords`` and ``dims``.

        Raises
        ------
        TypeError
            If the attribute is not a ``dict``.
        """
        data = getattr(self, group)
        if not isinstance(data, dict):
            raise TypeError("DictConverter.{} is not a dictionary".format(group))
        return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)

    @requires("posterior")
    def posterior_to_xarray(self):
        """Convert posterior samples to xarray.

        Logs a warning when the posterior dictionary contains a
        ``"log_likelihood"`` key. Raises ``TypeError`` if ``posterior`` is not
        a dictionary.
        """
        data = self.posterior
        # Only probe for the key on real dictionaries so that invalid input
        # reaches the TypeError in the helper without a membership test first.
        if isinstance(data, dict) and "log_likelihood" in data:
            _log.warning(
                "log_likelihood found in posterior."
                " For stats functions log_likelihood needs to be in sample_stats."
            )
        return self._group_to_xarray("posterior")

    @requires("sample_stats")
    def sample_stats_to_xarray(self):
        """Convert sample_stats samples to xarray.

        Raises ``TypeError`` if ``sample_stats`` is not a dictionary.
        """
        return self._group_to_xarray("sample_stats")

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Raises ``TypeError`` if ``posterior_predictive`` is not a dictionary.
        """
        return self._group_to_xarray("posterior_predictive")

    @requires("prior")
    def prior_to_xarray(self):
        """Convert prior samples to xarray.

        Raises ``TypeError`` if ``prior`` is not a dictionary.
        """
        return self._group_to_xarray("prior")

    @requires("sample_stats_prior")
    def sample_stats_prior_to_xarray(self):
        """Convert sample_stats_prior samples to xarray.

        Raises ``TypeError`` if ``sample_stats_prior`` is not a dictionary.
        """
        return self._group_to_xarray("sample_stats_prior")

    @requires("prior_predictive")
    def prior_predictive_to_xarray(self):
        """Convert prior_predictive samples to xarray.

        Raises ``TypeError`` if ``prior_predictive`` is not a dictionary.
        """
        return self._group_to_xarray("prior_predictive")

    @requires("observed_data")
    def observed_data_to_xarray(self):
        """Convert observed_data to xarray.

        Each value is promoted to at least one dimension and wrapped in its
        own ``xr.DataArray``; the arrays are then combined into a single
        ``xr.Dataset``. Raises ``TypeError`` if ``observed_data`` is not a
        dictionary.
        """
        data = self.observed_data
        if not isinstance(data, dict):
            raise TypeError("DictConverter.observed_data is not a dictionary")
        # ``dims`` is optional; fall back to an empty mapping so ``.get`` works.
        dims = {} if self.dims is None else self.dims
        observed_data = {}
        for key, vals in data.items():
            # Scalars become 1-element arrays so they always have a shape.
            vals = np.atleast_1d(vals)
            val_dims, coords = generate_dims_coords(
                vals.shape, key, dims=dims.get(key), coords=self.coords
            )
            observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
        return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=None))

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created, then the InferenceData
        will not have those groups.
        """
        # NOTE(review): skipping of missing groups presumably relies on the
        # ``requires`` decorator and on ``InferenceData`` -- confirm there.
        return InferenceData(
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "prior": self.prior_to_xarray(),
                "sample_stats_prior": self.sample_stats_prior_to_xarray(),
                "prior_predictive": self.prior_predictive_to_xarray(),
                "observed_data": self.observed_data_to_xarray(),
            }
        )


# pylint: disable=too-many-instance-attributes
def from_dict(
    posterior=None,
    *,
    posterior_predictive=None,
    sample_stats=None,
    prior=None,
    prior_predictive=None,
    sample_stats_prior=None,
    observed_data=None,
    coords=None,
    dims=None,
):
    """Convert Dictionary data into an InferenceData object.

    All arguments are forwarded unchanged to ``DictConverter``, whose
    ``to_inference_data`` method builds the result.

    Parameters
    ----------
    posterior : dict
    posterior_predictive : dict
    sample_stats : dict
        "log_likelihood" variable for stats needs to be here.
    prior : dict
    prior_predictive : dict
    sample_stats_prior : dict
    observed_data : dict
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable.

    Returns
    -------
    InferenceData object
    """
    # Gather the data groups first, then pass them together with the
    # shared coordinate/dimension metadata to the converter.
    groups = {
        "posterior": posterior,
        "posterior_predictive": posterior_predictive,
        "sample_stats": sample_stats,
        "prior": prior,
        "prior_predictive": prior_predictive,
        "sample_stats_prior": sample_stats_prior,
        "observed_data": observed_data,
    }
    converter = DictConverter(coords=coords, dims=dims, **groups)
    return converter.to_inference_data()
90 changes: 90 additions & 0 deletions arviz/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
convert_to_inference_data,
convert_to_dataset,
from_cmdstan,
from_dict,
from_pymc3,
from_pystan,
from_pyro,
Expand Down Expand Up @@ -387,6 +388,95 @@ def test_convert_to_dataset(self, eight_schools_params, draws, chains, data):
assert dataset.theta.shape == (chains, draws, eight_schools_params["J"])


class TestDictIONetCDFUtils:
    """Exercise ``from_dict`` with valid, edge-case and invalid inputs."""

    @pytest.fixture(scope="class")
    def data(self, eight_schools_params, draws, chains):
        """Return a holder class whose ``obj`` maps variable names to samples.

        The samples are extracted from the cached PyStan fit of the Eight
        Schools model, with the ``y_hat`` and ``log_lik`` variables left out.
        """
        extra_vars = {"y_hat", "log_lik"}

        class Data:
            _, stan_fit = load_cached_models(eight_schools_params, draws, chains)["pystan"]
            if pystan_version() == 2:
                stan_dict = pystan_extract_unpermuted(stan_fit)
                obj = {}
                for name, vals in stan_dict.items():
                    if name in extra_vars:
                        continue
                    # NOTE(review): presumably reorders the first two axes into
                    # the layout from_dict expects -- confirm against the
                    # extractor's output.
                    obj[name] = np.swapaxes(vals, 0, 1)
            else:
                stan_dict = stan_extract_dict(stan_fit)
                obj = {}
                for name, vals in stan_dict.items():
                    if name in extra_vars:
                        continue
                    obj[name] = vals

        return Data

    def check_var_names_coords_dims(self, dataset):
        """Assert that ``dataset`` has exactly the expected variables and coords."""
        expected_vars = {"mu", "tau", "eta", "theta"}
        expected_coords = {"chain", "draw", "school"}
        assert set(dataset.data_vars) == expected_vars
        assert set(dataset.coords) == expected_coords

    def get_inference_data(self, data, eight_schools_params):
        """Build an InferenceData in which every sampled group reuses ``data.obj``."""
        sampled_groups = dict.fromkeys(
            (
                "posterior",
                "posterior_predictive",
                "sample_stats",
                "prior",
                "prior_predictive",
                "sample_stats_prior",
            ),
            data.obj,
        )
        school_index = np.arange(eight_schools_params["J"])
        return from_dict(
            observed_data=eight_schools_params,
            coords={"school": school_index},
            dims={"theta": ["school"], "eta": ["school"]},
            **sampled_groups
        )

    def test_inference_data(self, data, eight_schools_params):
        """All seven groups exist and the sampled ones have the expected layout."""
        inference_data = self.get_inference_data(data, eight_schools_params)
        sampled_groups = (
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
        )
        for group in sampled_groups + ("observed_data",):
            assert hasattr(inference_data, group)
        # observed_data is only checked for presence, not for its layout.
        for group in sampled_groups:
            self.check_var_names_coords_dims(getattr(inference_data, group))

    def test_inference_data_edge_cases(self):
        """Unusual but valid inputs still produce an InferenceData object."""
        # create data
        log_likelihood = {
            "y": np.random.randn(4, 100),
            "log_likelihood": np.random.randn(4, 100, 8),
        }

        # log_likelihood to posterior
        result = from_dict(posterior=log_likelihood)
        assert result is not None

        # dims == None
        result = from_dict(observed_data=log_likelihood, dims=None)
        assert result is not None

    def test_inference_data_bad(self):
        """Passing an ndarray instead of a dict raises TypeError for each group."""
        # create data
        x = np.random.randn(4, 100)

        # input ndarray
        group_names = (
            "posterior",
            "posterior_predictive",
            "sample_stats",
            "prior",
            "prior_predictive",
            "sample_stats_prior",
            "observed_data",
        )
        for group in group_names:
            with pytest.raises(TypeError):
                from_dict(**{group: x})


class TestEmceeNetCDFUtils:
@pytest.fixture(scope="class")
def data(self, draws, chains):
Expand Down

0 comments on commit e88e1bf

Please sign in to comment.