[WIP] Add new groups to from_numpyro and from_dict #1125

Merged
merged 8 commits into from
Apr 16, 2020
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -11,7 +11,10 @@
* Revamped the `hpd` function to make it work with multidimensional arrays, InferenceData and xarray objects (#1117)
* Skip test for optional/extra dependencies when not installed (#1113)
* Add option to display rank plots instead of trace (#1134)

* Add out-of-sample groups (`predictions` and `predictions_constant_data`) to `from_dict` (#1125)
* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to the pyro and numpyro translations (#1090, #1125)
* Add `num_chains` and `pred_dims` arguments to `from_pyro` and `from_numpyro` (#1090, #1125)
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
65 changes: 41 additions & 24 deletions arviz/data/io_dict.py
@@ -16,25 +16,29 @@ def __init__(
*,
posterior=None,
posterior_predictive=None,
predictions=None,
sample_stats=None,
log_likelihood=None,
prior=None,
prior_predictive=None,
sample_stats_prior=None,
observed_data=None,
constant_data=None,
predictions_constant_data=None,
coords=None,
dims=None
):
self.posterior = posterior
self.posterior_predictive = posterior_predictive
self.predictions = predictions
self.sample_stats = sample_stats
self.log_likelihood = log_likelihood
self.prior = prior
self.prior_predictive = prior_predictive
self.sample_stats_prior = sample_stats_prior
self.observed_data = observed_data
self.constant_data = constant_data
self.predictions_constant_data = predictions_constant_data
self.coords = coords
self.dims = dims

@@ -89,6 +93,15 @@ def posterior_predictive_to_xarray(self):

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)

@requires("predictions")
def predictions_to_xarray(self):
"""Convert predictions to xarray."""
data = self.predictions
if not isinstance(data, dict):
raise TypeError("DictConverter.predictions is not a dictionary")

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)

@requires("prior")
def prior_to_xarray(self):
"""Convert prior samples to xarray."""
@@ -116,45 +129,41 @@ def prior_predictive_to_xarray(self):

return dict_to_dataset(data, library=None, coords=self.coords, dims=self.dims)

@requires("observed_data")
def observed_data_to_xarray(self):
"""Convert observed_data to xarray."""
data = self.observed_data
def data_to_xarray(self, dct, group):
"""Convert data to xarray."""
data = dct
if not isinstance(data, dict):
raise TypeError("DictConverter.observed_data is not a dictionary")
raise TypeError("DictConverter.{} is not a dictionary".format(group))
if self.dims is None:
dims = {}
else:
dims = self.dims
observed_data = dict()
new_data = dict()
for key, vals in data.items():
vals = utils.one_de(vals)
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(
vals.shape, key, dims=val_dims, coords=self.coords
)
observed_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=None))
new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=new_data, attrs=make_attrs(library=None))

@requires("observed_data")
def observed_data_to_xarray(self):
"""Convert observed_data to xarray."""
return self.data_to_xarray(self.observed_data, group="observed_data")

@requires("constant_data")
def constant_data_to_xarray(self):
"""Convert constant_data to xarray."""
data = self.constant_data
if not isinstance(data, dict):
raise TypeError("DictConverter.constant_data is not a dictionary")
if self.dims is None:
dims = {}
else:
dims = self.dims
constant_data = dict()
for key, vals in data.items():
vals = utils.one_de(vals)
val_dims = dims.get(key)
val_dims, coords = generate_dims_coords(
vals.shape, key, dims=val_dims, coords=self.coords
)
constant_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=None))
return self.data_to_xarray(self.constant_data, group="constant_data")

@requires("predictions_constant_data")
def predictions_constant_data_to_xarray(self):
Member

If it is not too much to ask, no pressure, could you do the same trick here as with constant data and predictions constant data in from_pyro and from_numpyro? If I am not mistaken, the code here is shared not only between the two constant data groups but also with the observed data group, so it would significantly reduce code duplication.

Contributor Author
@nitishp25 nitishp25 Mar 28, 2020

Of course! It makes much more sense. Thanks! Should I also do this for other groups like posterior_predictive, predictions, prior_predictive, etc.?

Member

It could be done, but I am not sure it is worth it; those methods are already only 3 lines long, the bulk of the work is done by the external function dict_to_dataset, and the only difference between them is the error message.

"""Convert predictions_constant_data to xarray."""
return self.data_to_xarray(
self.predictions_constant_data, group="predictions_constant_data"
)

def to_inference_data(self):
"""Convert all available data to an InferenceData object.
@@ -168,11 +177,13 @@
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihood": self.log_likelihood_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"predictions": self.predictions_to_xarray(),
"prior": self.prior_to_xarray(),
"sample_stats_prior": self.sample_stats_prior_to_xarray(),
"prior_predictive": self.prior_predictive_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
"constant_data": self.constant_data_to_xarray(),
"predictions_constant_data": self.predictions_constant_data_to_xarray(),
}
)

@@ -182,13 +193,15 @@ def from_dict(
posterior=None,
*,
posterior_predictive=None,
predictions=None,
sample_stats=None,
log_likelihood=None,
prior=None,
prior_predictive=None,
sample_stats_prior=None,
observed_data=None,
constant_data=None,
predictions_constant_data=None,
coords=None,
dims=None
):
@@ -198,13 +211,15 @@
----------
posterior : dict
posterior_predictive : dict
predictions: dict
sample_stats : dict
log_likelihood : dict
For stats functions, log likelihood data should be stored here.
prior : dict
prior_predictive : dict
observed_data : dict
constant_data : dict
predictions_constant_data: dict
coords : dict[str, iterable]
A dictionary containing the values that are used as index. The key
is the name of the dimension, the values are the index values.
@@ -218,13 +233,15 @@
return DictConverter(
posterior=posterior,
posterior_predictive=posterior_predictive,
predictions=predictions,
sample_stats=sample_stats,
log_likelihood=log_likelihood,
prior=prior,
prior_predictive=prior_predictive,
sample_stats_prior=sample_stats_prior,
observed_data=observed_data,
constant_data=constant_data,
predictions_constant_data=predictions_constant_data,
coords=coords,
dims=dims,
).to_inference_data()
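
A minimal sketch of how the new `from_dict` groups could be exercised; the variable names, shapes and coordinate values below are illustrative assumptions, not taken from this PR:

```python
# Hypothetical example: store out-of-sample predictions alongside the posterior.
import numpy as np
import arviz as az

nchains, ndraws, n_new = 4, 500, 8

idata = az.from_dict(
    posterior={"beta": np.random.randn(nchains, ndraws)},
    # out-of-sample predictions, shaped (chain, draw, *event_shape)
    predictions={"y_pred": np.random.randn(nchains, ndraws, n_new)},
    # constant data used only for those predictions (no chain/draw dimensions)
    predictions_constant_data={"x_pred": np.linspace(0, 1, n_new)},
    dims={"y_pred": ["new_obs"], "x_pred": ["new_obs"]},
    coords={"new_obs": np.arange(n_new)},
)
print(idata)  # the InferenceData should now list predictions and predictions_constant_data
```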
126 changes: 118 additions & 8 deletions arviz/data/io_numpyro.py
@@ -20,7 +20,18 @@ class NumPyroConverter:
ndraws = None # type: int

def __init__(
self, *, posterior=None, prior=None, posterior_predictive=None, coords=None, dims=None
self,
*,
posterior=None,
prior=None,
posterior_predictive=None,
predictions=None,
constant_data=None,
predictions_constant_data=None,
coords=None,
dims=None,
pred_dims=None,
num_chains=1
):
"""Convert NumPyro data into an InferenceData object.

@@ -32,21 +43,38 @@ def __init__(
Prior samples from a NumPyro model
posterior_predictive : dict
Posterior predictive samples for the posterior
predictions: dict
Out of sample predictions
constant_data: dict
Dictionary containing constant data variables mapped to their values.
predictions_constant_data: dict
Constant data used for out-of-sample predictions.
coords : dict[str] -> list[str]
Map of dimensions to coordinates
dims : dict[str] -> list[str]
Map variable names to their coordinates
pred_dims: dict
Dims for predictions data. Map variable names to their coordinates.
num_chains: int
Number of chains used for sampling. Ignored if posterior is present.
"""
import jax
import numpyro

self.posterior = posterior
self.prior = jax.device_get(prior)
self.posterior_predictive = jax.device_get(posterior_predictive)
self.predictions = predictions
self.constant_data = constant_data
self.predictions_constant_data = predictions_constant_data
self.coords = coords
self.dims = dims
self.pred_dims = pred_dims
self.numpyro = numpyro

def arbitrary_element(dct):
return next(iter(dct.values()))

if posterior is not None:
samples = jax.device_get(self.posterior.get_samples(group_by_chain=True))
if not isinstance(samples, dict):
@@ -65,7 +93,22 @@ def __init__(
self._args = self.posterior._args # pylint: disable=protected-access
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
else:
self.nchains = self.ndraws = 0
self.nchains = num_chains
get_from = None
if predictions is not None:
get_from = predictions
elif posterior_predictive is not None:
get_from = posterior_predictive
elif prior is not None:
get_from = prior
if get_from is None and constant_data is None and predictions_constant_data is None:
raise ValueError(
"When constructing InferenceData must have at least"
" one of posterior, prior, posterior_predictive or predictions."
)
if get_from is not None:
aelem = arbitrary_element(get_from)
self.ndraws = aelem.shape[0] // self.nchains

observations = {}
if self.model is not None:
@@ -120,11 +163,10 @@ def log_likelihood_to_xarray(self):
data[obs_name] = np.reshape(log_like.copy(), shape)
return dict_to_dataset(data, library=self.numpyro, dims=self.dims, coords=self.coords)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
"""Convert posterior_predictive or prediction samples to xarray."""
data = {}
for k, ary in self.posterior_predictive.items():
for k, ary in dct.items():
shape = ary.shape
if shape[0] == self.nchains and shape[1] == self.ndraws:
data[k] = ary
@@ -136,7 +178,19 @@ def posterior_predictive_to_xarray(self):
"posterior predictive shape not compatible with number of chains and draws. "
"This can mean that some draws or even whole chains are not represented."
)
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=dims)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(
self.posterior_predictive, self.dims
)

@requires("predictions")
def predictions_to_xarray(self):
"""Convert predictions to xarray."""
return self.translate_posterior_predictive_dict_to_xarray(self.predictions, self.pred_dims)

def priors_to_xarray(self):
"""Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -184,6 +238,32 @@ def observed_data_to_xarray(self):
observed_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=observed_data, attrs=make_attrs(library=self.numpyro))

def convert_constant_data_to_xarray(self, dct, dims):
"""Convert constant_data or predictions_constant_data to xarray."""
if dims is None:
dims = {}
constant_data = {}
for name, vals in dct.items():
vals = utils.one_de(vals)
val_dims = dims.get(name)
val_dims, coords = generate_dims_coords(
vals.shape, name, dims=val_dims, coords=self.coords
)
# filter coords based on the dims
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
constant_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
return xr.Dataset(data_vars=constant_data, attrs=make_attrs(library=self.numpyro))

@requires("constant_data")
def constant_data_to_xarray(self):
"""Convert constant_data to xarray."""
return self.convert_constant_data_to_xarray(self.constant_data, self.dims)

@requires("predictions_constant_data")
def predictions_constant_data_to_xarray(self):
"""Convert predictions_constant_data to xarray."""
return self.convert_constant_data_to_xarray(self.predictions_constant_data, self.pred_dims)

def to_inference_data(self):
"""Convert all available data to an InferenceData object.

@@ -197,13 +277,28 @@ def to_inference_data(self):
"sample_stats": self.sample_stats_to_xarray(),
"log_likelihood": self.log_likelihood_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"predictions": self.predictions_to_xarray(),
**self.priors_to_xarray(),
"observed_data": self.observed_data_to_xarray(),
"constant_data": self.constant_data_to_xarray(),
"predictions_constant_data": self.predictions_constant_data_to_xarray(),
}
)


def from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, coords=None, dims=None):
def from_numpyro(
posterior=None,
*,
prior=None,
posterior_predictive=None,
predictions=None,
constant_data=None,
predictions_constant_data=None,
coords=None,
dims=None,
pred_dims=None,
num_chains=1
):
"""Convert NumPyro data into an InferenceData object.

Parameters
@@ -214,15 +309,30 @@ def from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, coord
Prior samples from a NumPyro model
posterior_predictive : dict
Posterior predictive samples for the posterior
predictions: dict
Out of sample predictions
constant_data: dict
Dictionary containing constant data variables mapped to their values.
predictions_constant_data: dict
Constant data used for out-of-sample predictions.
coords : dict[str] -> list[str]
Map of dimensions to coordinates
dims : dict[str] -> list[str]
Map variable names to their coordinates
pred_dims: dict
Dims for predictions data. Map variable names to their coordinates.
num_chains: int
Number of chains used for sampling. Ignored if posterior is present.
"""
return NumPyroConverter(
posterior=posterior,
prior=prior,
posterior_predictive=posterior_predictive,
predictions=predictions,
constant_data=constant_data,
predictions_constant_data=predictions_constant_data,
coords=coords,
dims=dims,
pred_dims=pred_dims,
num_chains=num_chains,
).to_inference_data()
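
A minimal end-to-end sketch of the new `from_numpyro` arguments (`predictions`, `constant_data`, `predictions_constant_data`, `pred_dims`); the model, data and seeds are illustrative assumptions, not taken from this PR:

```python
# Hypothetical linear-regression example with out-of-sample predictions.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import arviz as az

x = jnp.linspace(0, 1, 10)
y = 2.0 * x + 0.1 * random.normal(random.PRNGKey(0), (10,))
x_new = jnp.linspace(1, 2, 5)  # inputs for out-of-sample predictions

def model(x, y=None):
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(beta * x, 0.1), obs=y)

# On a single device NumPyro will run the two chains sequentially.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=2)
mcmc.run(random.PRNGKey(1), x=x, y=y)

# Draw out-of-sample predictions from the fitted posterior.
predictions = Predictive(model, mcmc.get_samples())(random.PRNGKey(2), x=x_new)

idata = az.from_numpyro(
    mcmc,
    predictions=predictions,
    constant_data={"x": x},
    predictions_constant_data={"x_new": x_new},
    dims={"y": ["obs_id"], "x": ["obs_id"]},
    pred_dims={"y": ["pred_id"], "x_new": ["pred_id"]},
)
```

When no `posterior` is passed (for example, when only predictions generated elsewhere are converted), `num_chains` tells the converter how to split the flat sample dimension into chains and draws; as the diff shows, it is ignored whenever a posterior is present.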