diff --git a/modin/backends/pandas/query_compiler.py b/modin/backends/pandas/query_compiler.py
index 767704f1a13..72cdbc6d640 100644
--- a/modin/backends/pandas/query_compiler.py
+++ b/modin/backends/pandas/query_compiler.py
@@ -1181,25 +1181,39 @@ def groupby_reduce(
             by, type(self)
         ), "Can only use groupby reduce with another Query Compiler"
 
+        other_len = len(by.columns)
+
         def _map(df, other):
+            other = other.squeeze(axis=axis ^ 1)
+            if isinstance(other, pandas.DataFrame):
+                df = pandas.concat(
+                    [df] + [other[[o for o in other if o not in df]]], axis=1
+                )
+                other = list(other.columns)
             return map_func(
-                df.groupby(by=other.squeeze(axis=axis ^ 1), axis=axis, **groupby_args),
-                **map_args
+                df.groupby(by=other, axis=axis, **groupby_args), **map_args
             ).reset_index(drop=False)
 
         if reduce_func is not None:
 
             def _reduce(df):
+                if other_len > 1:
+                    by = list(df.columns[0:other_len])
+                else:
+                    by = df.columns[0]
                 return reduce_func(
-                    df.groupby(by=df.columns[0], axis=axis, **groupby_args),
-                    **reduce_args
+                    df.groupby(by=by, axis=axis, **groupby_args), **reduce_args
                 )
 
         else:
 
             def _reduce(df):
+                if other_len > 1:
+                    by = list(df.columns[0:other_len])
+                else:
+                    by = df.columns[0]
                 return map_func(
-                    df.groupby(by=df.columns[0], axis=axis, **groupby_args), **map_args
+                    df.groupby(by=by, axis=axis, **groupby_args), **map_args
                 )
 
         if axis == 0:
diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py
index 53e977a873f..b440e739c0a 100644
--- a/modin/pandas/dataframe.py
+++ b/modin/pandas/dataframe.py
@@ -410,20 +410,28 @@ def groupby(
             else:
                 by = self.__getitem__(by)._query_compiler
         elif is_list_like(by):
-            if isinstance(by, Series):
-                idx_name = by.name
-                by = by.values
-            mismatch = len(by) != len(self.axes[axis])
-            if mismatch and all(
-                obj in self
-                or (hasattr(self.index, "names") and obj in self.index.names)
-                for obj in by
-            ):
-                # In the future, we will need to add logic to handle this, but for now
-                # we default to pandas in this case.
-                pass
-            elif mismatch:
-                raise KeyError(next(x for x in by if x not in self))
+            # fastpath for multi column groupby
+            if axis == 0 and all(o in self for o in by):
+                warnings.warn(
+                    "Multi-column groupby is a new feature. "
+                    "Please report any bugs/issues to bug_reports@modin.org."
+                )
+                by = self.__getitem__(by)._query_compiler
+            else:
+                if isinstance(by, Series):
+                    idx_name = by.name
+                    by = by.values
+                mismatch = len(by) != len(self.axes[axis])
+                if mismatch and all(
+                    obj in self
+                    or (hasattr(self.index, "names") and obj in self.index.names)
+                    for obj in by
+                ):
+                    # In the future, we will need to add logic to handle this, but for now
+                    # we default to pandas in this case.
+                    pass
+                elif mismatch:
+                    raise KeyError(next(x for x in by if x not in self))
 
         from .groupby import DataFrameGroupBy
 
diff --git a/modin/pandas/groupby.py b/modin/pandas/groupby.py
index 7f016108ce5..b66baa2cd05 100644
--- a/modin/pandas/groupby.py
+++ b/modin/pandas/groupby.py
@@ -84,7 +84,13 @@ def __getattr__(self, key):
     @property
     def _index_grouped(self):
         if self._index_grouped_cache is None:
-            if self._is_multi_by:
+            if hasattr(self._by, "columns") and len(self._by.columns) > 1:
+                by = list(self._by.columns)
+                is_multi_by = True
+            else:
+                by = self._by
+                is_multi_by = self._is_multi_by
+            if is_multi_by:
                 # Because we are doing a collect (to_pandas) here and then groupby, we
                 # end up using pandas implementation. Add the warning so the user is
                 # aware.
@@ -92,9 +98,9 @@ def _index_grouped(self):
                 ErrorMessage.default_to_pandas("Groupby with multiple columns")
                 self._index_grouped_cache = {
                     k: v.index
-                    for k, v in self._df._query_compiler.getitem_column_array(self._by)
+                    for k, v in self._df._query_compiler.getitem_column_array(by)
                     .to_pandas()
-                    .groupby(by=self._by)
+                    .groupby(by=by)
                 }
             else:
                 if isinstance(self._by, type(self._query_compiler)):
diff --git a/modin/pandas/test/test_groupby.py b/modin/pandas/test/test_groupby.py
index 1cc8bb28d8d..45f8e6c7eea 100644
--- a/modin/pandas/test/test_groupby.py
+++ b/modin/pandas/test/test_groupby.py
@@ -465,8 +465,7 @@ def test_multi_column_groupby():
     ray_df = from_pandas(pandas_df)
    by = ["col1", "col2"]
 
-    with pytest.warns(UserWarning):
-        ray_df.groupby(by).count()
+    ray_df_equals_pandas(ray_df.groupby(by).count(), pandas_df.groupby(by).count())
 
     with pytest.warns(UserWarning):
         for k, _ in ray_df.groupby(by):
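
For reviewers: a minimal plain-pandas sketch of the map/reduce pattern the query_compiler.py change implements, using count() as the reduction. The partition contents and column names below are illustrative stand-ins, not Modin internals.

    import pandas

    # Two stand-in "partitions" of a distributed frame.
    part1 = pandas.DataFrame(
        {"col1": [0, 0, 1], "col2": ["a", "b", "a"], "col3": [1, 2, 3]}
    )
    part2 = pandas.DataFrame(
        {"col1": [0, 1, 1], "col2": ["a", "a", "b"], "col3": [4, 5, 6]}
    )
    keys = ["col1", "col2"]
    other_len = len(keys)  # mirrors other_len = len(by.columns) in the patch

    # Map phase: group each partition locally; reset_index(drop=False)
    # moves the group keys into the first other_len columns of the output.
    mapped = [
        p.groupby(by=keys).count().reset_index(drop=False) for p in (part1, part2)
    ]

    # Reduce phase: re-group the concatenated partials by those leading
    # columns, as the new _reduce does, and combine the partial counts.
    combined = pandas.concat(mapped)
    by = list(combined.columns[0:other_len]) if other_len > 1 else combined.columns[0]
    print(combined.groupby(by=by).sum())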
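
And a usage sketch of the user-facing change, assuming a working Modin install: a multi-column groupby on axis=0 where every key is a column now takes the distributed fastpath (emitting the new-feature UserWarning) instead of defaulting to pandas wholesale.

    import modin.pandas as pd

    df = pd.DataFrame(
        {"col1": [0, 1, 0, 1], "col2": ["a", "a", "b", "b"], "col3": [10, 20, 30, 40]}
    )
    # Runs through QueryCompiler.groupby_reduce rather than a full
    # to_pandas() round trip; iterating over groups still collects.
    print(df.groupby(["col1", "col2"]).count())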