
Commit

Merge branch 'main' into threadpool
* main:
  Optimize bitmask finding for chunk size 1 and single chunk cases (#360)
  Edits to climatology doc (#361)
dcherian committed Apr 27, 2024
2 parents 2823677 + 627bf2b commit 7da4364
Showing 4 changed files with 98 additions and 23 deletions.
13 changes: 13 additions & 0 deletions asv_bench/benchmarks/cohorts.py
@@ -190,6 +190,19 @@ def setup(self, *args, **kwargs):
self.expected = pd.RangeIndex(self.by.max() + 1)


class SingleChunk(Cohorts):
"""Single chunk along reduction axis: always blockwise."""

def setup(self, *args, **kwargs):
index = pd.date_range("1959-01-01", freq="D", end="1962-12-31")
self.time = pd.Series(index)
TIME = len(self.time)
self.axis = (2,)
self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, -1))
self.by = codes_for_resampling(index, freq="5D")
self.expected = pd.RangeIndex(self.by.max() + 1)


class OISST(Cohorts):
def setup(self, *args, **kwargs):
self.array = dask.array.ones((1, 14532), chunks=(1, 10))
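The `SingleChunk` benchmark above targets the new single-chunk shortcut. Here is a minimal sketch of the behaviour it measures, assuming only `flox.core.find_group_cohorts`; the array size and resampling-style labels below are invented for illustration:

```python
# Sketch only: toy sizes and labels; find_group_cohorts is the function the
# benchmark (and the optimization in flox/core.py below) exercises.
import dask.array
import numpy as np
from flox.core import find_group_cohorts

array = dask.array.ones((1461,), chunks=(-1,))       # one chunk along the reduction axis
labels = np.repeat(np.arange(293), 5)[: array.size]  # 5-day resampling-style codes

method, cohorts = find_group_cohorts(labels, array.chunks)
print(method)  # expected: "blockwise" -- a single chunk is always reduced blockwise
```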
52 changes: 41 additions & 11 deletions docs/source/user-stories/climatology.ipynb
@@ -61,7 +61,9 @@
"source": [
"To account for Feb-29 being present in some years, we'll construct a time vector to group by as \"mmm-dd\" string.\n",
"\n",
"For more options, see https://strftime.org/"
"```{seealso}\n",
"For more options, see [this great website](https://strftime.org/).\n",
"```"
]
},
{
@@ -80,7 +82,7 @@
"id": "6",
"metadata": {},
"source": [
"## map-reduce\n",
"## First, `method=\"map-reduce\"`\n",
"\n",
"The default\n",
"[method=\"map-reduce\"](https://flox.readthedocs.io/en/latest/implementation.html#method-map-reduce)\n",
@@ -110,7 +112,7 @@
"id": "8",
"metadata": {},
"source": [
"## Rechunking for map-reduce\n",
"### Rechunking for map-reduce\n",
"\n",
"We can split each chunk along the `lat`, `lon` dimensions to make sure the\n",
"output chunk sizes are more reasonable\n"
@@ -139,7 +141,7 @@
"But what if we didn't want to rechunk the dataset so drastically (note the 10x\n",
"increase in tasks). For that let's try `method=\"cohorts\"`\n",
"\n",
"## method=cohorts\n",
"## `method=\"cohorts\"`\n",
"\n",
"We can take advantage of patterns in the groups here \"day of year\".\n",
"Specifically:\n",
@@ -271,7 +273,7 @@
"id": "21",
"metadata": {},
"source": [
"And now our cohorts contain more than one group\n"
"And now our cohorts contain more than one group, *and* there is a substantial reduction in number of cohorts **162 -> 12**\n"
]
},
{
@@ -281,7 +283,7 @@
"metadata": {},
"outputs": [],
"source": [
"preferrd_method, new_cohorts = flox.core.find_group_cohorts(\n",
"preferred_method, new_cohorts = flox.core.find_group_cohorts(\n",
" labels=codes,\n",
" chunks=(rechunked.chunksizes[\"time\"],),\n",
")\n",
@@ -295,13 +297,23 @@
"id": "23",
"metadata": {},
"outputs": [],
"source": [
"preferred_method"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24",
"metadata": {},
"outputs": [],
"source": [
"new_cohorts.values()"
]
},
{
"cell_type": "markdown",
"id": "24",
"id": "25",
"metadata": {},
"source": [
"Now the groupby reduction **looks OK** in terms of number of tasks but remember\n",
@@ -311,7 +323,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "25",
"id": "26",
"metadata": {},
"outputs": [],
"source": [
@@ -320,7 +332,25 @@
},
{
"cell_type": "markdown",
"id": "26",
"id": "27",
"metadata": {},
"source": [
"flox's heuristics will choose `\"cohorts\"` automatically!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28",
"metadata": {},
"outputs": [],
"source": [
"flox.xarray.xarray_reduce(rechunked, day, func=\"mean\")"
]
},
{
"cell_type": "markdown",
"id": "29",
"metadata": {},
"source": [
"## How about other climatologies?\n",
@@ -331,7 +361,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "27",
"id": "30",
"metadata": {},
"outputs": [],
"source": [
@@ -340,7 +370,7 @@
},
{
"cell_type": "markdown",
"id": "28",
"id": "31",
"metadata": {},
"source": [
"This looks great. Why?\n",
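For readers who want to reproduce the notebook's pattern outside the docs, a self-contained sketch follows; the toy dataset, sizes, and chunking are invented, while `flox.xarray.xarray_reduce` and the "mmm-dd" grouping mirror the notebook:

```python
# Illustrative sketch only: the dataset, sizes, and chunking are made up;
# the "mmm-dd" grouping and xarray_reduce call follow the climatology doc.
import flox.xarray
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", "2003-12-31", freq="D")
ds = xr.Dataset(
    {"tos": (("time",), np.random.default_rng(0).random(time.size))},
    coords={"time": time},
).chunk({"time": 30})

# "mmm-dd" labels group Feb-29 separately from Mar-01, as the doc explains.
day = ds.time.dt.strftime("%b-%d").rename("day")

# flox's heuristics pick a method; method="cohorts" can also be passed explicitly.
clim = flox.xarray.xarray_reduce(ds, day, func="mean")
```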
37 changes: 27 additions & 10 deletions flox/core.py
@@ -249,12 +249,26 @@ def slices_from_chunks(chunks):


def _compute_label_chunk_bitmask(labels, chunks, nlabels):
def make_bitmask(rows, cols):
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows.shape)
return csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))

assert isinstance(labels, np.ndarray)
shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)
approx_chunk_size = math.prod(c[0] for c in chunks)

# Shortcut for 1D with size-1 chunks
if shape == (nchunks,):
rows_array = np.arange(nchunks)
cols_array = labels
mask = labels >= 0
return make_bitmask(rows_array[mask], cols_array[mask])

labels = np.broadcast_to(labels, shape[-labels.ndim :])
cols = []
# Add one to handle the -1 sentinel value
label_is_present = np.zeros((nlabels + 1,), dtype=bool)
ilabels = np.arange(nlabels)

def chunk_unique(labels, slicer, nlabels, label_is_present=None):
@@ -300,13 +314,10 @@ def chunk_unique(labels, slicer, nlabels, label_is_present=None):
uniques = chunk_unique(labels, region, nlabels, label_is_present)
cols.append(uniques)
label_is_present[:] = False

cols_array = np.concatenate(cols)
rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
cols_array = np.concatenate(cols)

return bitmask
return make_bitmask(rows_array, cols_array)


# @memoize
@@ -343,13 +354,18 @@ def find_group_cohorts(
labels = np.asarray(labels)

shape = tuple(sum(c) for c in chunks)
nchunks = math.prod(len(c) for c in chunks)

# assumes that `labels` are factorized
if expected_groups is None:
nlabels = labels.max() + 1
else:
nlabels = expected_groups[-1] + 1

# 1. Single chunk, blockwise always
if nchunks == 1:
return "blockwise", {(0,): list(range(nlabels))}

labels = np.broadcast_to(labels, shape[-labels.ndim :])
bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)

@@ -377,21 +393,21 @@ def invert(x) -> tuple[np.ndarray, ...]:

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

# 1. Every group is contained to one block, use blockwise here.
# 2. Every group is contained to one block, use blockwise here.
if bitmask.shape[CHUNK_AXIS] == 1 or (chunks_per_label == 1).all():
logger.info("find_group_cohorts: blockwise is preferred.")
return "blockwise", chunks_cohorts

# 2. Perfectly chunked so there is only a single cohort
# 3. Perfectly chunked so there is only a single cohort
if len(chunks_cohorts) == 1:
logger.info("Only found a single cohort. 'map-reduce' is preferred.")
return "map-reduce", chunks_cohorts if merge else {}

# 3. Our dataset has chunksize one along the axis,
# 4. Our dataset has chunksize one along the axis,
single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
# 4. Every chunk only has a single group, but that group might extend across multiple chunks
# 5. Every chunk only has a single group, but that group might extend across multiple chunks
one_group_per_chunk = (bitmask.sum(axis=LABEL_AXIS) == 1).all()
# 5. Existing cohorts don't overlap, great for time grouping with perfect chunking
# 6. Existing cohorts don't overlap, great for time grouping with perfect chunking
no_overlapping_cohorts = (np.bincount(np.concatenate(tuple(chunks_cohorts.keys()))) == 1).all()
if one_group_per_chunk or single_chunks or no_overlapping_cohorts:
logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
@@ -424,6 +440,7 @@ def invert(x) -> tuple[np.ndarray, ...]:
sparsity, MAX_SPARSITY_FOR_COHORTS
)
)
# 7. Groups seem fairly randomly distributed, use "map-reduce".
if sparsity > MAX_SPARSITY_FOR_COHORTS:
if not merge:
logger.info(
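To make the size-1-chunk shortcut concrete: with 1D data and one element per chunk, chunk `i` can only contain label `labels[i]`, so the chunk-by-label bitmask can be assembled directly from `(chunk, label)` pairs, skipping the per-chunk unique computation. A standalone toy illustration (the labels are invented; the `csc_array` call mirrors the one in the diff):

```python
# Toy illustration of the size-1-chunk shortcut; the labels are made up and
# -1 plays the role of flox's NaN sentinel.
import numpy as np
from scipy.sparse import csc_array

labels = np.array([0, 0, 1, 1, 2, -1, 2])
nchunks = labels.size            # 1D data, one element (and one chunk) per position
nlabels = labels.max() + 1

rows = np.arange(nchunks)
mask = labels >= 0               # drop sentinel entries instead of growing the matrix
data = np.ones(mask.sum(), dtype=np.uint8)
bitmask = csc_array(
    (data, (rows[mask], labels[mask])), dtype=bool, shape=(nchunks, nlabels)
)
print(bitmask.toarray().astype(int))
# [[1 0 0]
#  [1 0 0]
#  [0 1 0]
#  [0 1 0]
#  [0 0 1]
#  [0 0 0]
#  [0 0 1]]
```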
19 changes: 17 additions & 2 deletions tests/test_core.py
@@ -946,12 +946,12 @@ def test_verify_complex_cohorts(chunksize: int) -> None:
@pytest.mark.parametrize("chunksize", (12,) + tuple(range(1, 13)) + (-1,))
def test_method_guessing(chunksize):
# just a regression test
labels = np.tile(np.arange(1, 13), 30)
labels = np.tile(np.arange(0, 12), 30)
by = dask.array.from_array(labels, chunks=chunksize) - 1
preferred_method, chunks_cohorts = find_group_cohorts(labels, by.chunks[slice(-1, None)])
if chunksize == -1:
assert preferred_method == "blockwise"
assert chunks_cohorts == {(0,): list(range(1, 13))}
assert chunks_cohorts == {(0,): list(range(12))}
elif chunksize in (1, 2, 3, 4, 6):
assert preferred_method == "cohorts"
assert len(chunks_cohorts) == 12 // chunksize
@@ -960,6 +960,21 @@ def test_method_guessing(chunksize):
assert chunks_cohorts == {}


@requires_dask
@pytest.mark.parametrize("ndim", [1, 2, 3])
def test_single_chunk_method_is_blockwise(ndim):
for by_ndim in range(1, ndim + 1):
chunks = (5,) * (ndim - by_ndim) + (-1,) * by_ndim
assert len(chunks) == ndim
array = dask.array.ones(shape=(10,) * ndim, chunks=chunks)
by = np.zeros(shape=(10,) * by_ndim, dtype=int)
method, chunks_cohorts = find_group_cohorts(
by, chunks=[array.chunks[ax] for ax in range(-by.ndim, 0)]
)
assert method == "blockwise"
assert chunks_cohorts == {(0,): [0]}


@requires_dask
@pytest.mark.parametrize(
"chunk_at,expected",
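A quick interactive check mirroring `test_method_guessing` with `chunksize=1` (a sketch, not part of the test suite; the expected output is what the test asserts):

```python
# Sketch mirroring test_method_guessing with chunksize=1; the values come from
# the test above, and the printed expectation matches its assertions.
import numpy as np
from flox.core import find_group_cohorts

labels = np.tile(np.arange(12), 30)     # 12 groups repeated 30 times
chunks = ((1,) * labels.size,)          # size-1 chunks along the grouped axis

method, cohorts = find_group_cohorts(labels, chunks)
print(method, len(cohorts))             # expected: "cohorts" with 12 cohorts
```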
