Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

from_cmdstanpy - fix bug introduced by refactor (PR 1558) #1564

Merged
merged 13 commits into from
Feb 16, 2021
Merged
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ disable=missing-docstring,
unsubscriptable-object,
cyclic-import,
ungrouped-imports,

not-an-iterable,
no-member,
#TODO: Remove this once todos are done
fixme

Expand Down
13 changes: 10 additions & 3 deletions arviz/data/io_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,17 @@ def _unpack_fit(fit, items, save_warmup):
else:
raise ValueError("fit data, unknown variable: {}".format(item))
if save_warmup:
sample_warmup[item] = draws[:num_warmup, :, col_idxs]
sample[item] = draws[num_warmup:, :, col_idxs]
if len(col_idxs) == 1:
sample_warmup[item] = np.squeeze(draws[:num_warmup, :, col_idxs], axis=2)
sample[item] = np.squeeze(draws[num_warmup:, :, col_idxs], axis=2)
else:
sample_warmup[item] = draws[:num_warmup, :, col_idxs]
sample[item] = draws[num_warmup:, :, col_idxs]
else:
sample[item] = draws[:, :, col_idxs]
if len(col_idxs) == 1:
sample[item] = np.squeeze(draws[:, :, col_idxs], axis=2)
else:
sample[item] = draws[:, :, col_idxs]

return sample, sample_warmup

Expand Down
7 changes: 6 additions & 1 deletion arviz/tests/external_tests/test_data_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def test_sampler_stats(self, data, eight_schools_params):
test_dict = {"sample_stats": ["lp", "diverging"]}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert len(inference_data.sample_stats.lp.shape) == 2 # pylint: disable=no-member

def test_inference_data(self, data, eight_schools_params):
inference_data1 = self.get_inference_data(data, eight_schools_params)
Expand Down Expand Up @@ -354,6 +355,8 @@ def test_inference_data(self, data, eight_schools_params):
test_dict = {"posterior": ["theta"], "prior": ["theta"]}
fails = check_multiple_attrs(test_dict, inference_data4)
assert not fails
assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member
assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member

def test_inference_data_warmup(self, data, eight_schools_params):
inference_data_true_is_true = self.get_inference_data_warmup_true_is_true(
Expand Down Expand Up @@ -429,4 +432,6 @@ def test_inference_data_warmup(self, data, eight_schools_params):
assert "warmup_posterior" not in inference_data_false_is_false
assert "warmup_predictions" not in inference_data_false_is_false
assert "warmup_log_likelihood" not in inference_data_false_is_false
assert "warmup_prior" not in inference_data_false_is_false
assert (
"warmup_prior" not in inference_data_false_is_false
) # pylint: disable=redefined-outer-name