From c717956ca5359ba8c8c93c63b6bc0cdf52aeabd8 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sun, 14 Feb 2021 17:53:58 -0500 Subject: [PATCH 1/9] squeezing scalar vars --- arviz/data/io_cmdstanpy.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/arviz/data/io_cmdstanpy.py b/arviz/data/io_cmdstanpy.py index 638fba4844..67481bb8ee 100644 --- a/arviz/data/io_cmdstanpy.py +++ b/arviz/data/io_cmdstanpy.py @@ -494,10 +494,17 @@ def _unpack_fit(fit, items, save_warmup): else: raise ValueError("fit data, unknown variable: {}".format(item)) if save_warmup: - sample_warmup[item] = draws[:num_warmup, :, col_idxs] - sample[item] = draws[num_warmup:, :, col_idxs] + if len(col_idxs) == 1: + sample_warmup[item] = np.squeeze(draws[:num_warmup, :, col_idxs], axis=2) + sample[item] = np.squeeze(draws[num_warmup:, :, col_idxs], axis=2) + else: + sample_warmup[item] = draws[:num_warmup, :, col_idxs] + sample[item] = draws[num_warmup:, :, col_idxs] else: - sample[item] = draws[:, :, col_idxs] + if len(col_idxs) == 1: + sample[item] = np.squeeze(draws[:, :, col_idxs], axis=2) + else: + sample[item] = draws[:, :, col_idxs] return sample, sample_warmup From 889a402783cf24a9028b713da13ff08973e6a323 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sun, 14 Feb 2021 19:00:57 -0500 Subject: [PATCH 2/9] added checks on variable shape --- arviz/tests/external_tests/test_data_cmdstanpy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index bf3e12d9e5..472b436735 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -308,6 +308,7 @@ def test_sampler_stats(self, data, eight_schools_params): test_dict = {"sample_stats": ["lp", "diverging"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails + assert len(inference_data.sample_stats.lp.shape) == 2 def test_inference_data(self, data, eight_schools_params): inference_data1 = self.get_inference_data(data, eight_schools_params) @@ -354,6 +355,8 @@ def test_inference_data(self, data, eight_schools_params): test_dict = {"posterior": ["theta"], "prior": ["theta"]} fails = check_multiple_attrs(test_dict, inference_data4) assert not fails + assert len(inference_data4.posterior.theta.shape) == 3 + assert len(inference_data4.posterior.mu.shape) == 2 def test_inference_data_warmup(self, data, eight_schools_params): inference_data_true_is_true = self.get_inference_data_warmup_true_is_true( From d1b556f05a5ceb6ec71c271a617516936bd46d44 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sun, 14 Feb 2021 22:40:34 -0500 Subject: [PATCH 3/9] lint fix --- arviz/tests/external_tests/test_data_cmdstanpy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index 472b436735..4a09057fce 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -308,7 +308,7 @@ def test_sampler_stats(self, data, eight_schools_params): test_dict = {"sample_stats": ["lp", "diverging"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails - assert len(inference_data.sample_stats.lp.shape) == 2 + assert len(inference_data.sample_stats.lp.shape) == 2 # pylint: disable=no-member def test_inference_data(self, data, eight_schools_params): inference_data1 = self.get_inference_data(data, eight_schools_params) @@ -355,8 +355,8 @@ def test_inference_data(self, data, eight_schools_params): test_dict = {"posterior": ["theta"], "prior": ["theta"]} fails = check_multiple_attrs(test_dict, inference_data4) assert not fails - assert len(inference_data4.posterior.theta.shape) == 3 - assert len(inference_data4.posterior.mu.shape) == 2 + assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member + assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member def test_inference_data_warmup(self, data, eight_schools_params): inference_data_true_is_true = self.get_inference_data_warmup_true_is_true( From f6f6c4e3e5e4a149d001abd4df48373397abb853 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 15 Feb 2021 10:42:40 -0500 Subject: [PATCH 4/9] remove checks on variable shape --- arviz/tests/external_tests/test_data_cmdstanpy.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index 4a09057fce..bf3e12d9e5 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -308,7 +308,6 @@ def test_sampler_stats(self, data, eight_schools_params): test_dict = {"sample_stats": ["lp", "diverging"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails - assert len(inference_data.sample_stats.lp.shape) == 2 # pylint: disable=no-member def test_inference_data(self, data, eight_schools_params): inference_data1 = self.get_inference_data(data, eight_schools_params) @@ -355,8 +354,6 @@ def test_inference_data(self, data, eight_schools_params): test_dict = {"posterior": ["theta"], "prior": ["theta"]} fails = check_multiple_attrs(test_dict, inference_data4) assert not fails - assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member - assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member def test_inference_data_warmup(self, data, eight_schools_params): inference_data_true_is_true = self.get_inference_data_warmup_true_is_true( From 2e2fdea2795b4d419ec3c71e32ab5710f069b102 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 15 Feb 2021 13:28:47 -0500 Subject: [PATCH 5/9] adding back shape checks --- arviz/tests/external_tests/test_data_cmdstanpy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index bf3e12d9e5..400b9d1c9f 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -308,6 +308,7 @@ def test_sampler_stats(self, data, eight_schools_params): test_dict = {"sample_stats": ["lp", "diverging"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails + assert len(inference_data.sample_stats.lp.shape) == 2 # pylint: disable=no-member def test_inference_data(self, data, eight_schools_params): inference_data1 = self.get_inference_data(data, eight_schools_params) @@ -354,6 +355,8 @@ def test_inference_data(self, data, eight_schools_params): test_dict = {"posterior": ["theta"], "prior": ["theta"]} fails = check_multiple_attrs(test_dict, inference_data4) assert not fails + assert len(inference_data4.posterior.theta.shape) == 3 # pylint: disable=no-member + assert len(inference_data4.posterior.mu.shape) == 2 # pylint: disable=no-member def test_inference_data_warmup(self, data, eight_schools_params): inference_data_true_is_true = self.get_inference_data_warmup_true_is_true( @@ -429,4 +432,4 @@ def test_inference_data_warmup(self, data, eight_schools_params): assert "warmup_posterior" not in inference_data_false_is_false assert "warmup_predictions" not in inference_data_false_is_false assert "warmup_log_likelihood" not in inference_data_false_is_false - assert "warmup_prior" not in inference_data_false_is_false + assert "warmup_prior" not in inference_data_false_is_false# pylint: disable=redefined-outer-name From 9058187e49cf0700971a90b9dcc85ea6f814e02c Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 15 Feb 2021 19:56:27 -0500 Subject: [PATCH 6/9] pylint - diable no-member, not-an-iterable --- .pylintrc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 54e92986bb..b4b70a2a75 100644 --- a/.pylintrc +++ b/.pylintrc @@ -70,7 +70,9 @@ disable=missing-docstring, unsubscriptable-object, cyclic-import, ungrouped-imports, - + not-an-iterable + no-member + #TODO: Remove this once todos are done fixme From 55778d245c7e27ae31694423ff39e9a65b988462 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 16 Feb 2021 00:24:46 +0200 Subject: [PATCH 7/9] black --- arviz/tests/external_tests/test_data_cmdstanpy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index 400b9d1c9f..6d64ff480b 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -432,4 +432,6 @@ def test_inference_data_warmup(self, data, eight_schools_params): assert "warmup_posterior" not in inference_data_false_is_false assert "warmup_predictions" not in inference_data_false_is_false assert "warmup_log_likelihood" not in inference_data_false_is_false - assert "warmup_prior" not in inference_data_false_is_false# pylint: disable=redefined-outer-name + assert ( + "warmup_prior" not in inference_data_false_is_false + ) # pylint: disable=redefined-outer-name From b1986ae85bb1a7627e4dfa4d3a00f3b82cd6c583 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Tue, 16 Feb 2021 03:38:01 +0200 Subject: [PATCH 8/9] Update .pylintrc indentation --- .pylintrc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.pylintrc b/.pylintrc index b4b70a2a75..d82d70fbee 100644 --- a/.pylintrc +++ b/.pylintrc @@ -70,9 +70,8 @@ disable=missing-docstring, unsubscriptable-object, cyclic-import, ungrouped-imports, - not-an-iterable - no-member - + not-an-iterable + no-member #TODO: Remove this once todos are done fixme From f69a024833b0ef4fda176f2b6269798c70884399 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Tue, 16 Feb 2021 05:44:28 +0200 Subject: [PATCH 9/9] add commas --- .pylintrc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pylintrc b/.pylintrc index d82d70fbee..f629f9ed8d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -70,8 +70,8 @@ disable=missing-docstring, unsubscriptable-object, cyclic-import, ungrouped-imports, - not-an-iterable - no-member + not-an-iterable, + no-member, #TODO: Remove this once todos are done fixme