Skip to content

Commit

Permalink
Add support for grouping by multiple columns when doing a reduction (#987)
Browse files Browse the repository at this point in the history
* Resolves #75
* Adds support for grouping by multiple columns.
* Performs this grouping by broadcasting the columns.
  * A preliminary performance evaluation shows that it is significantly
    faster than before, but still has some room for improvement.
* Minimal code changes to add this new feature.
* We still default to pandas when the user is looping over the dataframe
  * Even though this is common, it is exceptionally hard to optimize,
    and out of scope for this PR.
  • Loading branch information
devin-petersohn authored Jan 13, 2020
1 parent 7874281 commit 08f9af5
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 24 deletions.
24 changes: 19 additions & 5 deletions modin/backends/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,25 +1181,39 @@ def groupby_reduce(
by, type(self)
), "Can only use groupby reduce with another Query Compiler"

other_len = len(by.columns)

def _map(df, other):
other = other.squeeze(axis=axis ^ 1)
if isinstance(other, pandas.DataFrame):
df = pandas.concat(
[df] + [other[[o for o in other if o not in df]]], axis=1
)
other = list(other.columns)
return map_func(
df.groupby(by=other.squeeze(axis=axis ^ 1), axis=axis, **groupby_args),
**map_args
df.groupby(by=other, axis=axis, **groupby_args), **map_args
).reset_index(drop=False)

if reduce_func is not None:

def _reduce(df):
if other_len > 1:
by = list(df.columns[0:other_len])
else:
by = df.columns[0]
return reduce_func(
df.groupby(by=df.columns[0], axis=axis, **groupby_args),
**reduce_args
df.groupby(by=by, axis=axis, **groupby_args), **reduce_args
)

else:

def _reduce(df):
if other_len > 1:
by = list(df.columns[0:other_len])
else:
by = df.columns[0]
return map_func(
df.groupby(by=df.columns[0], axis=axis, **groupby_args), **map_args
df.groupby(by=by, axis=axis, **groupby_args), **map_args
)

if axis == 0:
Expand Down
36 changes: 22 additions & 14 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,20 +410,28 @@ def groupby(
else:
by = self.__getitem__(by)._query_compiler
elif is_list_like(by):
if isinstance(by, Series):
idx_name = by.name
by = by.values
mismatch = len(by) != len(self.axes[axis])
if mismatch and all(
obj in self
or (hasattr(self.index, "names") and obj in self.index.names)
for obj in by
):
# In the future, we will need to add logic to handle this, but for now
# we default to pandas in this case.
pass
elif mismatch:
raise KeyError(next(x for x in by if x not in self))
# fastpath for multi column groupby
if axis == 0 and all(o in self for o in by):
warnings.warn(
"Multi-column groupby is a new feature. "
"Please report any bugs/issues to [email protected]."
)
by = self.__getitem__(by)._query_compiler
else:
if isinstance(by, Series):
idx_name = by.name
by = by.values
mismatch = len(by) != len(self.axes[axis])
if mismatch and all(
obj in self
or (hasattr(self.index, "names") and obj in self.index.names)
for obj in by
):
# In the future, we will need to add logic to handle this, but for now
# we default to pandas in this case.
pass
elif mismatch:
raise KeyError(next(x for x in by if x not in self))

from .groupby import DataFrameGroupBy

Expand Down
12 changes: 9 additions & 3 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,23 @@ def __getattr__(self, key):
@property
def _index_grouped(self):
if self._index_grouped_cache is None:
if self._is_multi_by:
if hasattr(self._by, "columns") and len(self._by.columns) > 1:
by = list(self._by.columns)
is_multi_by = True
else:
by = self._by
is_multi_by = self._is_multi_by
if is_multi_by:
# Because we are doing a collect (to_pandas) here and then groupby, we
# end up using pandas implementation. Add the warning so the user is
# aware.
ErrorMessage.catch_bugs_and_request_email(self._axis == 1)
ErrorMessage.default_to_pandas("Groupby with multiple columns")
self._index_grouped_cache = {
k: v.index
for k, v in self._df._query_compiler.getitem_column_array(self._by)
for k, v in self._df._query_compiler.getitem_column_array(by)
.to_pandas()
.groupby(by=self._by)
.groupby(by=by)
}
else:
if isinstance(self._by, type(self._query_compiler)):
Expand Down
3 changes: 1 addition & 2 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,7 @@ def test_multi_column_groupby():
ray_df = from_pandas(pandas_df)
by = ["col1", "col2"]

with pytest.warns(UserWarning):
ray_df.groupby(by).count()
ray_df_equals_pandas(ray_df.groupby(by).count(), pandas_df.groupby(by).count())

with pytest.warns(UserWarning):
for k, _ in ray_df.groupby(by):
Expand Down

0 comments on commit 08f9af5

Please sign in to comment.