Use dot product for containment (#306)
* Use dot product

* avoid advanced indexing

* small edits

* Cache label_chunks as earlier

* WIP

* Readd chunks_cohorts groupby

* Fix

* comments

* Remove cache test since find_group_cohorts is a lot faster now
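
The heart of the change is replacing the old quadratic merge loop with a containment measure computed as a single matrix product on the chunk-by-label bitmask. A minimal dense NumPy sketch of that idea follows; it is illustrative only — flox builds a sparse bitmask, and the actual implementation is in find_group_cohorts in the diff below — with toy arrays and a threshold chosen to mirror the diff.

import numpy as np

# Toy bitmask: rows are chunks, columns are labels (groups).
# bitmask[i, j] == 1 means label j occurs in chunk i.
bitmask = np.array(
    [
        [1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 1, 1],
    ],
    dtype=float,
)

# Number of chunks each label touches.
chunks_per_label = bitmask.sum(axis=0)

# containment[q, s] = |chunks(q) & chunks(s)| / |chunks(q)|:
# the numerator for every pair of labels is one matrix product.
containment = (bitmask.T @ bitmask) / chunks_per_label[:, np.newaxis]

MIN_CONTAINMENT = 0.75  # same arbitrary threshold as in the diff
print(containment >= MIN_CONTAINMENT)
# Labels 0 and 1 fully contain each other -> one cohort;
# labels 2 and 3 form a second, non-overlapping cohort.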
dcherian authored Jan 2, 2024
1 parent 80ae6a4 commit 15abf49
Showing 3 changed files with 68 additions and 62 deletions.
87 changes: 48 additions & 39 deletions flox/core.py
@@ -247,7 +247,7 @@ def _compute_label_chunk_bitmask(labels, chunks, nlabels):
return bitmask


@memoize
# @memoize
def find_group_cohorts(
labels, chunks, merge: bool = True, expected_groups: None | pd.RangeIndex = None
) -> dict:
@@ -286,7 +286,6 @@ def find_group_cohorts(
nlabels = expected_groups[-1] + 1

labels = np.broadcast_to(labels, shape[-labels.ndim :])
ilabels = np.arange(nlabels)
bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)

CHUNK_AXIS, LABEL_AXIS = 0, 1
@@ -303,54 +302,64 @@
for lab in range(bitmask.shape[-1])
}

# These invert the label_chunks mapping so we know which labels occur together.
# Invert the label_chunks mapping so we know which labels occur together.
def invert(x) -> tuple[np.ndarray, ...]:
arr = label_chunks.get(x)
return tuple(arr) # type: ignore [arg-type] # pandas issue?
arr = label_chunks[x]
return tuple(arr)

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

# If our dataset has chunksize one along the axis,
# then no merging is possible.
# No merging is possible when
# 1. Our dataset has chunksize one along the axis,
single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
# 2. Every chunk only has a single group, but that group might extend across multiple chunks
one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
# every group is contained to one block, we should be using blockwise here.
# 3. Every group is contained to one block, we should be using blockwise here.
every_group_one_block = (chunks_per_label == 1).all()
if every_group_one_block or one_group_per_chunk or single_chunks or not merge:
return chunks_cohorts

# First sort by number of chunks occupied by cohort
sorted_chunks_cohorts = dict(
sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
)
# 4. Existing cohorts don't overlap, great for time grouping with perfect chunking
no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()

# precompute needed metrics for the quadratic loop below.
items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
if (
every_group_one_block
or one_group_per_chunk
or single_chunks
or no_overlapping_cohorts
or not merge
):
return chunks_cohorts

# Containment = |Q & S| / |Q|
# - |X| is the cardinality of set X
# - Q is the query set being tested
# - S is the existing set
MIN_CONTAINMENT = 0.75 # arbitrary
asfloat = bitmask.astype(float)
containment = ((asfloat.T @ asfloat) / chunks_per_label[present_labels]).tocsr()
mask = containment.data < MIN_CONTAINMENT
containment.data[mask] = 0
containment.eliminate_zeros()

# Iterate over labels, beginning with those with most chunks
order = np.argsort(containment.sum(axis=LABEL_AXIS))[::-1]
merged_cohorts = {}
merged_keys: set[tuple] = set()

# Now we iterate starting with the longest number of chunks,
# and then merge in cohorts that are present in a subset of those chunks
# I think this is suboptimal and must fail at some point.
# But it might work for most cases. There must be a better way...
for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
if k1 in merged_keys:
merged_keys = set()
# TODO: we can optimize this to loop over chunk_cohorts instead
# by zeroing out rows that are already in a cohort
for rowidx in order:
cohort_ = containment.indices[
slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
]
cohort = [elem for elem in cohort_ if elem not in merged_keys]
if not cohort:
continue
new_key = set_k1
new_value = v1
# iterate in reverse since we expect small cohorts
# to be most likely merged in to larger ones
for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
if k2 not in merged_keys:
if (len(set_k2 & new_key) / len_k2) > 0.75:
new_key |= set_k2
new_value += v2
merged_keys.update((k2,))
sorted_ = sorted(new_value)
merged_cohorts[tuple(sorted(new_key))] = sorted_
if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
break
merged_keys.update(cohort)
allchunks = (label_chunks[member] for member in cohort)
chunk = tuple(set(itertools.chain(*allchunks)))
merged_cohorts[chunk] = cohort

actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
expected_ngroups = bitmask.shape[LABEL_AXIS]
assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)

# sort by first label in cohort
# This will help when sort=True (default)
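After thresholding, the merge loop above reads each row of the sparse containment matrix straight out of its CSR buffers (indptr and indices). The following standalone sketch shows that access pattern on a toy SciPy matrix; it is not flox code — unlike the real function, it keys each cohort by the seed row rather than by the tuple of chunk indices — and it assumes scipy is installed.

import numpy as np
from scipy import sparse

# Toy stand-in for the thresholded `containment` matrix (labels x labels).
dense = np.array(
    [
        [1.0, 0.8, 0.0],
        [0.9, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
)
containment = sparse.csr_matrix(dense)

# Visit rows with the largest total containment first, as the diff does.
order = np.argsort(np.asarray(containment.sum(axis=1)).ravel())[::-1]

merged_keys: set[int] = set()
merged_cohorts: dict[int, list[int]] = {}
for rowidx in order:
    start, stop = containment.indptr[rowidx], containment.indptr[rowidx + 1]
    # Column indices of this row's nonzero entries are the candidate cohort members.
    cohort = [int(j) for j in containment.indices[start:stop] if j not in merged_keys]
    if not cohort:
        continue
    merged_keys.update(cohort)
    merged_cohorts[int(rowidx)] = cohort

print(merged_cohorts)  # {1: [0, 1], 2: [2]}: labels 0 and 1 merge, label 2 stays alone
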
3 changes: 0 additions & 3 deletions tests/test_core.py
@@ -847,11 +847,8 @@ def test_rechunk_for_blockwise(inchunks, expected):
"expected, labels, chunks, merge",
[
[[[0, 1, 2, 3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), True],
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 4), False],
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), False],
[[[0], [1], [2], [3]], [0, 1, 2, 0, 1, 2, 3], (2, 2, 2, 1), True],
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), True],
[[[0, 1, 2], [3]], [0, 1, 2, 0, 1, 2, 3], (3, 3, 1), False],
[
[[0], [1, 2, 3, 4], [5]],
np.repeat(np.arange(6), [4, 4, 12, 2, 3, 4]),
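For orientation, the first (surviving) parametrization above corresponds roughly to the call sketched below. The per-axis nesting of the chunks argument is inferred from the function body in flox/core.py rather than copied from the test, so treat it as an assumption.

import numpy as np
from flox.core import find_group_cohorts

# Labels [0, 1, 2, 0, 1, 2, 3] along one axis split into chunks of sizes 3 and 4.
labels = np.asarray([0, 1, 2, 0, 1, 2, 3])
chunks = ((3, 4),)  # assumed per-axis nesting

cohorts = find_group_cohorts(labels, chunks, merge=True)
# Keys are tuples of chunk indices, values are the labels in each cohort;
# per the parametrization, merging yields a single cohort [0, 1, 2, 3].
print(list(cohorts.values()))
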
40 changes: 20 additions & 20 deletions tests/test_xarray.py
@@ -367,26 +367,26 @@ def test_func_is_aggregation():
xarray_reduce(ds.Tair, ds.time.dt.month, func=mean, skipna=False)


@requires_dask
def test_cache():
pytest.importorskip("cachey")

from flox.cache import cache

ds = xr.Dataset(
{
"foo": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
"bar": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
},
coords={"labels": ("y", np.repeat([1, 2], 10))},
)

cache.clear()
xarray_reduce(ds, "labels", func="mean", method="cohorts")
assert len(cache.data) == 1

xarray_reduce(ds, "labels", func="mean", method="blockwise")
assert len(cache.data) == 2
# @requires_dask
# def test_cache():
# pytest.importorskip("cachey")

# from flox.cache import cache

# ds = xr.Dataset(
# {
# "foo": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
# "bar": (("x", "y"), dask.array.ones((10, 20), chunks=2)),
# },
# coords={"labels": ("y", np.repeat([1, 2], 10))},
# )

# cache.clear()
# xarray_reduce(ds, "labels", func="mean", method="cohorts")
# assert len(cache.data) == 1

# xarray_reduce(ds, "labels", func="mean", method="blockwise")
# assert len(cache.data) == 2


@requires_dask