Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for HLG layers in dask-cudf groupby tests #10853

Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
18 commits
Select commit. Hold Shift and click to select a range.
6571499
Add checks for dask-cudf groupby HLG layers
charlesbluca May 13, 2022
9bc6645
Remove some tests that are now superfluous
charlesbluca May 13, 2022
1115f1c
Merge remote-tracking branch 'origin/branch-22.10' into checkout-grou…
charlesbluca Sep 13, 2022
1c700cb
Update check to use new layer names
charlesbluca Sep 13, 2022
2bec53a
Merge remote-tracking branch 'origin/branch-22.12' into checkout-grou…
charlesbluca Nov 1, 2022
543f8db
Update python/dask_cudf/dask_cudf/tests/test_groupby.py
charlesbluca Nov 2, 2022
b77a14a
Rename layer assertion function, improve error message
charlesbluca Nov 2, 2022
9d458f8
Parametrize nulls in pdf fixtures, use in test_groupby_agg
charlesbluca Nov 2, 2022
c7eaad1
Merge cumulative agg tests into main groupby tests
charlesbluca Nov 2, 2022
069fc65
Add back in test_groupby_first_last with some changes
charlesbluca Nov 2, 2022
711470f
Revert "Merge cumulative agg tests into main groupby tests"
charlesbluca Nov 2, 2022
2112fea
xfail cumulative test for null dataframes
charlesbluca Nov 3, 2022
60649ba
Remove cumulative aggs from SUPPORTED_AGGS
charlesbluca Nov 3, 2022
5b2fe76
Merge remote-tracking branch 'origin/branch-22.12' into checkout-grou…
charlesbluca Nov 3, 2022
2001ffe
Rename util groupby functions from supported -> optimized
charlesbluca Nov 3, 2022
6eddbcf
Merge remote-tracking branch 'origin/branch-22.12' into checkout-grou…
charlesbluca Nov 3, 2022
cc64ebc
Wrap groupby decorator with functools.wraps
charlesbluca Nov 3, 2022
78be977
Merge remote-tracking branch 'origin/branch-22.12' into checkout-grou…
charlesbluca Nov 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 34 additions & 40 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,12 +16,8 @@
import cudf
from cudf.utils.utils import _dask_cudf_nvtx_annotate

CUMULATIVE_AGGS = (
"cumsum",
"cumcount",
)

AGGS = (
# aggregations that are dask-cudf optimized
OPTIMIZED_AGGS = (
"count",
"mean",
"std",
Expand All @@ -34,19 +30,17 @@
"last",
)

SUPPORTED_AGGS = (*AGGS, *CUMULATIVE_AGGS)


def _check_groupby_supported(func):
def _check_groupby_optimized(func):
"""
Decorator for dask-cudf's groupby methods that returns the dask-cudf
method if the groupby object is supported, otherwise reverting to the
upstream Dask method
optimized method if the groupby object is supported, otherwise
reverting to the upstream Dask method
"""

def wrapper(*args, **kwargs):
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
gb = args[0]
if _groupby_supported(gb):
if _groupby_optimized(gb):
return func(*args, **kwargs)
# note that we use upstream Dask's default kwargs for this call if
# none are specified; this shouldn't be an issue as those defaults are
Expand Down Expand Up @@ -94,7 +88,7 @@ def _make_groupby_method_aggs(self, agg_name):
return {c: agg_name for c in self.obj.columns if c != self.by}

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -109,7 +103,7 @@ def count(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -124,7 +118,7 @@ def mean(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -139,7 +133,7 @@ def std(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -154,7 +148,7 @@ def var(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -169,7 +163,7 @@ def sum(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -184,7 +178,7 @@ def min(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -199,7 +193,7 @@ def max(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -214,7 +208,7 @@ def collect(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -229,7 +223,7 @@ def first(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -250,7 +244,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):

arg = _redirect_aggs(arg)

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
if isinstance(self._meta.grouping.keys, cudf.MultiIndex):
keys = self._meta.grouping.keys.names
else:
Expand Down Expand Up @@ -287,7 +281,7 @@ def __init__(self, *args, sort=None, **kwargs):
super().__init__(*args, sort=sort, **kwargs)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -302,7 +296,7 @@ def count(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -317,7 +311,7 @@ def mean(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -332,7 +326,7 @@ def std(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -347,7 +341,7 @@ def var(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -362,7 +356,7 @@ def sum(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -377,7 +371,7 @@ def min(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -392,7 +386,7 @@ def max(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -407,7 +401,7 @@ def collect(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -422,7 +416,7 @@ def first(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -446,7 +440,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
if not isinstance(arg, dict):
arg = {self._slice: arg}

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
return groupby_agg(
self.obj,
self.by,
Expand Down Expand Up @@ -569,9 +563,9 @@ def groupby_agg(
"""
# Assert that aggregations are supported
aggs = _redirect_aggs(aggs_in)
if not _aggs_supported(aggs, SUPPORTED_AGGS):
if not _aggs_optimized(aggs, OPTIMIZED_AGGS):
raise ValueError(
f"Supported aggs include {SUPPORTED_AGGS} for groupby_agg API. "
f"Supported aggs include {OPTIMIZED_AGGS} for groupby_agg API. "
f"Aggregations must be specified with dict or list syntax."
)

Expand Down Expand Up @@ -735,7 +729,7 @@ def _redirect_aggs(arg):


@_dask_cudf_nvtx_annotate
def _aggs_supported(arg, supported: set):
def _aggs_optimized(arg, supported: set):
"""Check that aggregations in `arg` are a subset of `supported`"""
if isinstance(arg, (list, dict)):
if isinstance(arg, dict):
Expand All @@ -757,8 +751,8 @@ def _aggs_supported(arg, supported: set):


@_dask_cudf_nvtx_annotate
def _groupby_supported(gb):
"""Check that groupby input is supported by dask-cudf"""
def _groupby_optimized(gb):
"""Check that groupby input can use dask-cudf optimized codepath"""
return isinstance(gb.obj, DaskDataFrame) and (
isinstance(gb.by, str)
or (isinstance(gb.by, list) and all(isinstance(x, str) for x in gb.by))
Expand Down Expand Up @@ -830,7 +824,7 @@ def _tree_node_agg(df, gb_cols, dropna, sort, sep):
agg = col.split(sep)[-1]
if agg in ("count", "sum"):
agg_dict[col] = ["sum"]
elif agg in SUPPORTED_AGGS:
elif agg in OPTIMIZED_AGGS:
agg_dict[col] = [agg]
else:
raise ValueError(f"Unexpected aggregation: {agg}")
Expand Down
Loading