Skip to content

Commit

Permalink
Disable dims, default_dims, and index_origin options until arviz > v0…
Browse files Browse the repository at this point in the history
….11.2
  • Loading branch information
brandonwillard committed Mar 26, 2021
1 parent 78ff887 commit e01a473
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions pymc3/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,15 @@ def posterior_to_xarray(self):
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
index_origin=self.index_origin,
# index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
index_origin=self.index_origin,
# index_origin=self.index_origin,
),
)

Expand Down Expand Up @@ -344,15 +344,15 @@ def sample_stats_to_xarray(self):
dims=None,
coords=self.coords,
attrs=self.attrs,
index_origin=self.index_origin,
# index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
dims=None,
coords=self.coords,
attrs=self.attrs,
index_origin=self.index_origin,
# index_origin=self.index_origin,
),
)

Expand Down Expand Up @@ -385,15 +385,15 @@ def log_likelihood_to_xarray(self):
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
index_origin=self.index_origin,
# index_origin=self.index_origin,
),
dict_to_dataset(
data_warmup,
library=pymc3,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
index_origin=self.index_origin,
# index_origin=self.index_origin,
),
)

Expand All @@ -415,7 +415,11 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
k,
)
return dict_to_dataset(
data, library=pymc3, coords=self.coords, dims=self.dims, index_origin=self.index_origin
data,
library=pymc3,
coords=self.coords,
# dims=self.dims,
# index_origin=self.index_origin
)

@requires(["posterior_predictive"])
Expand Down Expand Up @@ -450,8 +454,8 @@ def priors_to_xarray(self):
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
library=pymc3,
coords=self.coords,
dims=self.dims,
index_origin=self.index_origin,
# dims=self.dims,
# index_origin=self.index_origin,
)
)
return priors_dict
Expand All @@ -466,9 +470,9 @@ def observed_data_to_xarray(self):
{**self.observations, **self.multi_observations},
library=pymc3,
coords=self.coords,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
# dims=self.dims,
# default_dims=[],
# index_origin=self.index_origin,
)

@requires(["trace", "predictions"])
Expand Down Expand Up @@ -513,9 +517,9 @@ def is_data(name, var) -> bool:
constant_data,
library=pymc3,
coords=self.coords,
dims=self.dims,
default_dims=[],
index_origin=self.index_origin,
# dims=self.dims,
# default_dims=[],
# index_origin=self.index_origin,
)

def to_inference_data(self):
Expand Down

0 comments on commit e01a473

Please sign in to comment.