Consolidate validation of expected_groups (#193)
* Consolidate validation of expected_groups

* Add tests

* Switch by to nby

Co-authored-by: Illviljan <[email protected]>

* Type expected_groups properly

Co-authored-by: Illviljan <[email protected]>
dcherian and Illviljan authored Nov 28, 2022
1 parent b58aa5f commit 8ec0617
Showing 4 changed files with 71 additions and 23 deletions.
49 changes: 37 additions & 12 deletions flox/core.py
@@ -28,7 +28,9 @@
if TYPE_CHECKING:
import dask.array.Array as DaskArray

T_ExpectedGroups = Union[Sequence, np.ndarray, pd.Index]
T_Expect = Union[Sequence, np.ndarray, pd.Index, None]
T_ExpectTuple = tuple[T_Expect, ...]
T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
T_Func = Union[str, Callable]
T_Funcs = Union[T_Func, Sequence[T_Func]]
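
As a rough illustration (not part of the diff), the widened aliases can be read like this: T_Expect covers a single grouper's expectation, T_ExpectTuple holds one entry per grouper. The variable names below are hypothetical, and the aliases are private module-level names imported here only for demonstration:

import numpy as np
import pandas as pd
from flox.core import T_Expect, T_ExpectTuple  # private aliases introduced above

one_list: T_Expect = ["a", "b", "c"]        # Sequence
one_array: T_Expect = np.array([1, 2, 3])   # np.ndarray
one_index: T_Expect = pd.Index([10, 20])    # pd.Index
per_grouper: T_ExpectTuple = (np.array([1, 2]), ["a", "b"], None)  # one entry per grouper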
@@ -1476,7 +1478,7 @@ def _assert_by_is_aligned(shape, by):


def _convert_expected_groups_to_index(
expected_groups: T_ExpectedGroups, isbin: Sequence[bool], sort: bool
expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
) -> tuple[pd.Index | None, ...]:
out: list[pd.Index | None] = []
for ex, isbin_ in zip(expected_groups, isbin):
@@ -1543,6 +1545,36 @@ def _factorize_multiple(by, expected_groups, any_by_dask, reindex):
return (group_idx,), final_groups, grp_shape


def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:

if expected_groups is None:
return (None,) * nby

if nby == 1 and not isinstance(expected_groups, tuple):
return (np.asarray(expected_groups),)

if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list
raise ValueError(
"When grouping by multiple variables, expected_groups must be a tuple "
"of either arrays or objects convertible to an array (like lists). "
"For example `expected_groups=(np.array([1, 2, 3]), ['a', 'b', 'c'])`."
f"Received a {type(expected_groups).__name__} instead. "
"When grouping by a single variable, you can pass an array or something "
"convertible to an array for convenience: `expected_groups=['a', 'b', 'c']`."
)

if TYPE_CHECKING:
assert isinstance(expected_groups, tuple)

if len(expected_groups) != nby:
raise ValueError(
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
f" and variables to group by (received {nby})."
)

return expected_groups


def groupby_reduce(
array: np.ndarray | DaskArray,
*by: np.ndarray | DaskArray,
@@ -1679,24 +1711,17 @@ def groupby_reduce(
isbins = isbin
else:
isbins = (isbin,) * nby
if expected_groups is None:
expected_groups = (None,) * nby

_assert_by_is_aligned(array.shape, bys)

expected_groups = _validate_expected_groups(nby, expected_groups)

for idx, (expect, is_dask) in enumerate(zip(expected_groups, by_is_dask)):
if is_dask and (reindex or nby > 1) and expect is None:
raise ValueError(
f"`expected_groups` for array {idx} in `by` cannot be None since it is a dask.array."
)

if nby == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
elif len(expected_groups) != nby:
raise ValueError(
f"Must have same number of `expected_groups` (received {len(expected_groups)}) "
f" and variables to group by (received {nby})."
)

# We convert to pd.Index since that lets us know if we are binning or not
# (pd.IntervalIndex or not)
expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)
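For reference, a minimal sketch (not from the diff) of how the new helper normalizes its input. It is a private function, so importing it directly from flox.core is an assumption and the location may change:

from flox.core import _validate_expected_groups

print(_validate_expected_groups(1, ["a", "b"]))      # (array(['a', 'b'], dtype='<U1'),)
print(_validate_expected_groups(2, None))            # (None, None)
print(_validate_expected_groups(2, (None, [1, 2])))  # tuples of the right length pass through unchanged

# Both of these raise ValueError: a non-tuple with nby > 1, and a tuple whose
# length does not match the number of groupers.
# _validate_expected_groups(2, [[1, 2], [3, 4]])
# _validate_expected_groups(2, (None,))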
11 changes: 3 additions & 8 deletions flox/xarray.py
@@ -13,6 +13,7 @@
from .core import (
_convert_expected_groups_to_index,
_get_expected_groups,
_validate_expected_groups,
groupby_reduce,
rechunk_for_blockwise as rechunk_array_for_blockwise,
rechunk_for_cohorts as rechunk_array_for_cohorts,
@@ -216,16 +217,10 @@ def xarray_reduce(
else:
isbins = (isbin,) * nby

if expected_groups is None:
expected_groups = (None,) * nby
if isinstance(expected_groups, (np.ndarray, list)): # TODO: test for list
if nby == 1:
expected_groups = (expected_groups,)
else:
raise ValueError("Needs better message.")
expected_groups = _validate_expected_groups(nby, expected_groups)

if not sort:
raise NotImplementedError
raise NotImplementedError("sort must be True for xarray_reduce")

# eventually drop the variables we are grouping by
maybe_drop = [b for b in by if isinstance(b, Hashable)]
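A hypothetical end-to-end call (not taken from this PR; the DataArray and its coordinates are made up) showing the user-facing rule the shared validation enforces: with more than one grouper, expected_groups must be a tuple with one entry per grouper, while a bare list or array is still accepted for a single grouper:

import numpy as np
import xarray as xr
from flox.xarray import xarray_reduce

da = xr.DataArray(
    [1.0, 2.0, 3.0, 4.0],
    dims="x",
    coords={"labels": ("x", [1, 2, 1, 2]), "labels2": ("x", ["a", "a", "b", "b"])},
)

# A bare list here would now raise the ValueError from _validate_expected_groups;
# a tuple with one entry per grouper is accepted.
result = xarray_reduce(
    da,
    "labels",
    "labels2",
    func="count",
    expected_groups=(np.array([1, 2]), ["a", "b"]),
)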
20 changes: 19 additions & 1 deletion tests/test_core.py
@@ -1019,8 +1019,26 @@ def test_multiple_groupers(chunk, by1, by2, expected_groups) -> None:
assert_equal(expected, actual)


@pytest.mark.parametrize(
"expected_groups",
(
[None, None, None],
(None,),
),
)
def test_validate_expected_groups(expected_groups):
with pytest.raises(ValueError):
groupby_reduce(
np.ones((10,)),
np.ones((10,)),
np.ones((10,)),
expected_groups=expected_groups,
func="mean",
)


@requires_dask
def test_multiple_groupers_errors() -> None:
def test_validate_expected_groups_not_none_dask() -> None:
with pytest.raises(ValueError):
groupby_reduce(
dask.array.ones((5, 2)),
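By way of contrast, an untested sketch (not part of this commit; the label arrays are invented) of the call the renamed test guards against: when `by` is a dask array and there is more than one grouper, the group labels are unknown until compute time, so expected_groups must be supplied rather than left as None:

import dask.array
import numpy as np
from flox.core import groupby_reduce

labels1 = dask.array.from_array(np.array([0, 1, 0, 1, 0]))
labels2 = dask.array.from_array(np.array([0, 1, 2, 0, 1]))

# Providing expected_groups up front is what makes this reduction valid;
# expected_groups=None here would raise the ValueError tested above.
result, groups1, groups2 = groupby_reduce(
    dask.array.ones((5,)),
    labels1,
    labels2,
    expected_groups=(np.array([0, 1]), np.array([0, 1, 2])),
    func="count",
)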
14 changes: 12 additions & 2 deletions tests/test_xarray.py
@@ -168,12 +168,22 @@ def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine):


@requires_dask
def test_dask_groupers_error():
@pytest.mark.parametrize(
"expected_groups",
(None, (None, None), [[1, 2], [1, 2]]),
)
def test_validate_expected_groups(expected_groups):
da = xr.DataArray(
[1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])}
)
with pytest.raises(ValueError):
xarray_reduce(da.chunk({"x": 2, "z": 1}), "labels", "labels2", func="count")
xarray_reduce(
da.chunk({"x": 1}),
"labels",
"labels2",
func="count",
expected_groups=expected_groups,
)


@requires_dask
