Skip to content

Commit

Permalink
cohorts: Delete the merge kwarg (#313)
Browse files Browse the repository at this point in the history
So fast now, we do it always!
  • Loading branch information
dcherian authored Jan 5, 2024
1 parent 58bc9be commit 0c4a7f9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 29 deletions.
22 changes: 5 additions & 17 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):


# @memoize
def find_group_cohorts(
labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
) -> dict:
def find_group_cohorts(labels, chunks, expected_groups: None | pd.RangeIndex = None) -> dict:
"""
Finds groups labels that occur together aka "cohorts"
Expand All @@ -265,9 +263,8 @@ def find_group_cohorts(
represents NaNs.
chunks : tuple
chunks of the array being reduced
merge : bool, optional
Attempt to merge cohorts when one cohort's chunks are a subset
of another cohort's chunks.
expected_groups: pd.RangeIndex (optional)
Used to extract the largest label expected
Returns
-------
Expand Down Expand Up @@ -322,13 +319,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
# 4. Existing cohorts don't overlap, great for time grouping with perfect chunking
no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()

if (
every_group_one_block
or one_group_per_chunk
or single_chunks
or no_overlapping_cohorts
or not merge
):
if every_group_one_block or one_group_per_chunk or single_chunks or no_overlapping_cohorts:
return chunks_cohorts

# Containment = |Q & S| / |Q|
Expand Down Expand Up @@ -1569,10 +1560,7 @@ def dask_groupby_agg(

elif method == "cohorts":
chunks_cohorts = find_group_cohorts(
by_input,
[array.chunks[ax] for ax in axis],
merge=True,
expected_groups=expected_groups,
by_input, [array.chunks[ax] for ax in axis], expected_groups=expected_groups
)
reduced_ = []
groups_ = []
Expand Down
19 changes: 7 additions & 12 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,21 +844,16 @@ def test_rechunk_for_blockwise(inchunks, expected):

@requires_dask
@pytest.mark.parametrize(
"expected, labels, chunks, merge",
"expected, labels, chunks",
[
[[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), True],
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), True],
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), True],
[
[[0], [1, 2, 3, 4], [5]],
np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
(4, 8, 4, 9, 4),
True,
],
[[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4)],
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1)],
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1)],
[[[0], [1, 2, 3, 4], [5]], np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]), (4, 8, 4, 9, 4)],
],
)
def test_find_group_cohorts(expected, labels, chunks: tuple[int], merge: bool) -> None:
actual = list(find_group_cohorts(labels, (chunks,), merge).values())
def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
actual = list(find_group_cohorts(labels, (chunks,)).values())
assert actual == expected, (actual, expected)


Expand Down

0 comments on commit 0c4a7f9

Please sign in to comment.