Skip to content

Commit

Permalink
[CHORE] Improve error messages when calling aggregation methods on da…
Browse files Browse the repository at this point in the history
…taframe without input columns (#1587)

Fixes #1583.

When a user does not specify columns in df aggregation methods, e.g.
`df.count()`:
- Default to running aggregation on all columns
- Log warning messages with an example to pass in columns.
  • Loading branch information
colin-ho authored Nov 13, 2023
1 parent 0052c17 commit b4467ad
Showing 1 changed file with 36 additions and 7 deletions.
43 changes: 36 additions & 7 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# For technical details, see https://github.com/Eventual-Inc/Daft/pull/630

import pathlib
import warnings
from dataclasses import dataclass
from functools import reduce
from typing import (
Expand Down Expand Up @@ -858,7 +859,11 @@ def sum(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated sums. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing sum on all columns. Specify columns using df.sum('col1', 'col2', ...)."
)
cols = tuple(self.columns)
return self._agg([(c, "sum") for c in cols])

@DataframePublicAPI
Expand All @@ -870,7 +875,11 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated mean. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing mean on all columns. Specify columns using df.mean('col1', 'col2', ...)."
)
cols = tuple(self.columns)
return self._agg([(c, "mean") for c in cols])

@DataframePublicAPI
Expand All @@ -882,7 +891,11 @@ def min(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated min. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing min on all columns. Specify columns using df.min('col1', 'col2', ...)."
)
cols = tuple(self.columns)
return self._agg([(c, "min") for c in cols])

@DataframePublicAPI
Expand All @@ -894,7 +907,11 @@ def max(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated max. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing max on all columns. Specify columns using df.max('col1', 'col2', ...)."
)
cols = tuple(self.columns)
return self._agg([(c, "max") for c in cols])

@DataframePublicAPI
Expand All @@ -906,7 +923,11 @@ def count(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated count. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts."
)
cols = tuple(self.columns)
return self._agg([(c, "count") for c in cols])

@DataframePublicAPI
Expand All @@ -918,7 +939,11 @@ def agg_list(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated list. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing agg_list on all columns. Specify columns using df.agg_list('col1', 'col2', ...)."
)
cols = tuple(self.columns)
return self._agg([(c, "list") for c in cols])

@DataframePublicAPI
Expand All @@ -930,7 +955,11 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
Returns:
DataFrame: Globally aggregated list. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
warnings.warn(
"No columns specified; performing agg_concat on all columns. Specify columns using df.agg_concat('col1', 'col2', ...)."
)
cols = tuple(self.columns)
return self._agg([(c, "concat") for c in cols])

@DataframePublicAPI
Expand Down

0 comments on commit b4467ad

Please sign in to comment.