Skip to content

Commit

Permalink
Try return_array from _finalize_results
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 26, 2022
1 parent ee6be26 commit cb25e38
Showing 1 changed file with 43 additions and 16 deletions.
59 changes: 43 additions & 16 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,10 +803,15 @@ def _aggregate(
keepdims,
fill_value: Any,
reindex: bool,
return_array: bool,
) -> FinalResultsDict:
"""Final aggregation step of tree reduction"""
results = combine(x_chunk, agg, axis, keepdims, is_aggregate=True)
return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
finalized = _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
if return_array:
return finalized[agg.name]
else:
return finalized


def _expand_dims(results: IntermediateDict) -> IntermediateDict:
Expand Down Expand Up @@ -1287,6 +1292,7 @@ def dask_groupby_agg(
group_chunks: tuple[tuple[Union[int, float], ...]] = (
(len(expected_groups),) if expected_groups is not None else (np.nan,),
)
groups_are_unknown = is_duck_dask_array(by_input) and expected_groups is None

if method in ["map-reduce", "cohorts"]:
combine: Callable[..., IntermediateDict]
Expand Down Expand Up @@ -1316,16 +1322,32 @@ def dask_groupby_agg(
reduced = tree_reduce(
intermediate,
combine=partial(combine, agg=agg),
aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
aggregate=partial(
aggregate,
expected_groups=expected_groups,
reindex=reindex,
return_array=not groups_are_unknown,
),
)
if is_duck_dask_array(by_input) and expected_groups is None:
if groups_are_unknown:
groups = _extract_unknown_groups(reduced, group_chunks=group_chunks, dtype=by.dtype)
result = dask.array.map_blocks(
_extract_result,
reduced,
chunks=reduced.chunks[: -len(axis)] + group_chunks,
drop_axis=axis[:-1],
dtype=agg.dtype[agg.name],
key=agg.name,
name=f"{name}-{token}",
)

else:
if expected_groups is None:
expected_groups_ = _get_expected_groups(by_input, sort=sort)
else:
expected_groups_ = expected_groups
groups = (expected_groups_.to_numpy(),)
result = reduced

elif method == "cohorts":
chunks_cohorts = find_group_cohorts(
Expand All @@ -1344,12 +1366,14 @@ def dask_groupby_agg(
tree_reduce(
reindexed,
combine=partial(combine, agg=agg, reindex=True),
aggregate=partial(aggregate, expected_groups=index, reindex=True),
aggregate=partial(
aggregate, expected_groups=index, reindex=True, return_array=True
),
)
)
groups_.append(cohort)

reduced = dask.array.concatenate(reduced_, axis=-1)
result = dask.array.concatenate(reduced_, axis=-1)
groups = (np.concatenate(groups_),)
group_chunks = (tuple(len(cohort) for cohort in groups_),)

Expand All @@ -1375,21 +1399,24 @@ def dask_groupby_agg(
for ax, chunks in zip(axis, group_chunks):
adjust_chunks[ax] = chunks

result = dask.array.blockwise(
_extract_result,
inds[: -len(axis)] + (inds[-1],),
reduced,
inds,
adjust_chunks=adjust_chunks,
dtype=agg.dtype[agg.name],
key=agg.name,
name=f"{name}-{token}",
)

# result = dask.array.blockwise(
# _extract_result,
# inds[: -len(axis)] + (inds[-1],),
# reduced,
# inds,
# adjust_chunks=adjust_chunks,
# dtype=agg.dtype[agg.name],
# key=agg.name,
# name=f"{name}-{token}",
# )
return (result, groups)


def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray:
from dask.array.core import deepfirst

if not isinstance(result_dict, dict):
result_dict = deepfirst(result_dict)
return result_dict[key]


Expand Down

0 comments on commit cb25e38

Please sign in to comment.