From 8ec06172bca687111851e5b7b5d169ea991f50bd Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 28 Nov 2022 13:51:18 -0700 Subject: [PATCH] Consolidate validation of expected_groups (#193) * Consolidate validation of expected_groups * Add tests * Switch by to nby Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Type expected_groups properly Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- flox/core.py | 49 +++++++++++++++++++++++++++++++++----------- flox/xarray.py | 11 +++------- tests/test_core.py | 20 +++++++++++++++++- tests/test_xarray.py | 14 +++++++++++-- 4 files changed, 71 insertions(+), 23 deletions(-) diff --git a/flox/core.py b/flox/core.py index 728d8b186..2463b7dc6 100644 --- a/flox/core.py +++ b/flox/core.py @@ -28,7 +28,9 @@ if TYPE_CHECKING: import dask.array.Array as DaskArray - T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index] + T_Expect = Union[Sequence, np.ndarray, pd.Index, None] + T_ExpectTuple = tuple[T_Expect, ...] + T_ExpectedGroups = Union[T_Expect, T_ExpectTuple] T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None] T_Func = Union[str, Callable] T_Funcs = Union[T_Func, Sequence[T_Func]] @@ -1476,7 +1478,7 @@ def _assert_by_is_aligned(shape, by): def _convert_expected_groups_to_index( - expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool + expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool ) -> tuple[pd.Index | None, ...]: out: list[pd.Index | None] = [] for ex, isbin_ in zip(expected_groups, isbin): @@ -1543,6 +1545,36 @@ def _factorize_multiple(by, expected_groups, any_by_dask, reindex): return (group_idx,), final_groups, grp_shape +def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple: + + if expected_groups is None: + return (None,) * nby + + if nby == 1 and not isinstance(expected_groups, tuple): + return (np.asarray(expected_groups),) + + if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list + raise ValueError( + "When grouping by multiple variables, expected_groups must be a tuple " + "of either arrays or objects convertible to an array (like lists). " + "For example `expected_groups=(np.array([1, 2, 3]), ['a', 'b', 'c'])`." + f"Received a {type(expected_groups).__name__} instead. " + "When grouping by a single variable, you can pass an array or something " + "convertible to an array for convenience: `expected_groups=['a', 'b', 'c']`." + ) + + if TYPE_CHECKING: + assert isinstance(expected_groups, tuple) + + if len(expected_groups) != nby: + raise ValueError( + f"Must have same number of `expected_groups` (received {len(expected_groups)}) " + f" and variables to group by (received {nby})." + ) + + return expected_groups + + def groupby_reduce( array: np.ndarray | DaskArray, *by: np.ndarray | DaskArray, @@ -1679,24 +1711,17 @@ def groupby_reduce( isbins = isbin else: isbins = (isbin,) * nby - if expected_groups is None: - expected_groups = (None,) * nby _assert_by_is_aligned(array.shape, bys) + + expected_groups = _validate_expected_groups(nby, expected_groups) + for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)): if is_dask and (reindex or nby > 1) and expect is None: raise ValueError( f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array." ) - if nby == 1 and not isinstance(expected_groups, tuple): - expected_groups = (np.asarray(expected_groups),) - elif len(expected_groups) != nby: - raise ValueError( - f"Must have same number of `expected_groups` (received {len(expected_groups)}) " - f" and variables to group by (received {nby})." - ) - # We convert to pd.Index since that lets us know if we are binning or not # (pd.IntervalIndex or not) expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort) diff --git a/flox/xarray.py b/flox/xarray.py index 2a2b27bdc..314cf39a4 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -13,6 +13,7 @@ from .core import ( _convert_expected_groups_to_index, _get_expected_groups, + _validate_expected_groups, groupby_reduce, rechunk_for_blockwise as rechunk_array_for_blockwise, rechunk_for_cohorts as rechunk_array_for_cohorts, @@ -216,16 +217,10 @@ def xarray_reduce( else: isbins = (isbin,) * nby - if expected_groups is None: - expected_groups = (None,) * nby - if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list - if nby == 1: - expected_groups = (expected_groups,) - else: - raise ValueError("Needs better message.") + expected_groups = _validate_expected_groups(nby, expected_groups) if not sort: - raise NotImplementedError + raise NotImplementedError("sort must be True for xarray_reduce") # eventually drop the variables we are grouping by maybe_drop = [b for b in by if isinstance(b, Hashable)] diff --git a/tests/test_core.py b/tests/test_core.py index 3270e6151..8249927cc 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1019,8 +1019,26 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None: assert_equal(expected, actual) +@pytest.mark.parametrize( + "expected_groups", + ( + [None, None, None], + (None,), + ), +) +def test_validate_expected_groups(expected_groups): + with pytest.raises(ValueError): + groupby_reduce( + np.ones((10,)), + np.ones((10,)), + np.ones((10,)), + expected_groups=expected_groups, + func="mean", + ) + + @requires_dask -def test_multiple_groupers_errors() -> None: +def test_validate_expected_groups_not_none_dask() -> None: with pytest.raises(ValueError): groupby_reduce( dask.array.ones((5, 2)), diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 42e6ee17f..226867b1a 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -168,12 +168,22 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine): @requires_dask -def test_dask_groupers_error(): +@pytest.mark.parametrize( + "expected_groups", + (None, (None, None), [[1, 2], [1, 2]]), +) +def test_validate_expected_groups(expected_groups): da = xr.DataArray( [1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])} ) with pytest.raises(ValueError): - xarray_reduce(da.chunk({"x": 2, "z": 1}), "labels", "labels2", func="count") + xarray_reduce( + da.chunk({"x": 1}), + "labels", + "labels2", + func="count", + expected_groups=expected_groups, + ) @requires_dask