Skip to content

Commit

Permalink
validate filter_vars (#1772)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 30, 2021
1 parent bb2af3b commit 956a8cc
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
* Fixed xarray related tests. ([1726](https://github.com/arviz-devs/arviz/pull/1726))
* Fix Bokeh deprecation warnings ([1657](https://github.com/arviz-devs/arviz/pull/1657))
* Fix credible inteval percentage in legend in `plot_loo_pit` ([1745](https://github.com/arviz-devs/arviz/pull/1745))
* Arguments `filter_vars` and `filter_groups` now raise `ValueError` if illegal arguments are passed ([1772](https://github.com/arviz-devs/arviz/pull/1772))

### Deprecation
* Deprecated `index_origin` and `order` arguments in `az.summary` ([1201](https://github.com/arviz-devs/arviz/pull/1201))
Expand Down
5 changes: 5 additions & 0 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,11 @@ def _group_names(
-------
groups: list
"""
if filter_groups not in {None, "like", "regex"}:
raise ValueError(
f"'filter_groups' can only be None, 'like', or 'regex', got: '{filter_groups}'"
)

all_groups = self._groups_all
if groups is None:
return all_groups
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,15 @@ def test_group_names(self, args_res):
group_names = idata._group_names(*args) # pylint: disable=protected-access
assert np.all([name in result for name in group_names])

def test_group_names_invalid_args(self):
ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
idata = InferenceData(posterior=(ds, ds))
msg = r"^\'filter_groups\' can only be None, \'like\', or \'regex\', got: 'foo'$"
with pytest.raises(ValueError, match=msg):
idata._group_names( # pylint: disable=protected-access
("posterior",), filter_groups="foo"
)

@pytest.mark.parametrize("inplace", [False, True])
def test_isel(self, data_random, inplace):
idata = data_random
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/base_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ def test_var_names_filter(var_args):
assert _var_names(var_names, data, filter_vars) == expected


def test_var_names_filter_invalid_argument():
"""Check invalid argument raises."""
samples = np.random.randn(10)
data = dict_to_dataset({"alpha": samples})
msg = r"^\'filter_vars\' can only be None, \'like\', or \'regex\', got: 'foo'$"
with pytest.raises(ValueError, match=msg):
assert _var_names(["alpha"], data, filter_vars="foo")


def test_subset_list_negation_not_found():
"""Check there is a warning if negation pattern is ignored"""
names = ["mu", "theta"]
Expand Down
5 changes: 5 additions & 0 deletions arviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def _var_names(var_names, data, filter_vars=None):
-------
var_name: list or None
"""
if filter_vars not in {None, "like", "regex"}:
raise ValueError(
f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
)

if var_names is not None:
if isinstance(data, (list, tuple)):
all_vars = []
Expand Down

0 comments on commit 956a8cc

Please sign in to comment.