Skip to content

Commit

Permalink
FIX-#2482: improved handling non-str 'by' (#2548)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <[email protected]>
  • Loading branch information
dchigarev authored Dec 18, 2020
1 parent cee481b commit 9d9dc29
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 14 deletions.
4 changes: 2 additions & 2 deletions modin/backends/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from modin.backends.base.query_compiler import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.utils import try_cast_to_pandas, wrap_udf_function
from modin.utils import try_cast_to_pandas, wrap_udf_function, hashable
from modin.data_management.functions import (
Function,
FoldFunction,
Expand Down Expand Up @@ -2555,7 +2555,7 @@ def is_reduce_fn(fn, deep_level=0):
else:
if not isinstance(by, list):
by = [by]
internal_by = [o for o in by if isinstance(o, str) and o in self.columns]
internal_by = [o for o in by if hashable(o) and o in self.columns]
internal_qc = (
[self.getitem_column_array(internal_by)] if len(internal_by) else []
)
Expand Down
6 changes: 4 additions & 2 deletions modin/data_management/functions/groupby_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pandas

from .mapreducefunction import MapReduceFunction
from modin.utils import try_cast_to_pandas
from modin.utils import try_cast_to_pandas, hashable


class GroupbyReduceFunction(MapReduceFunction):
Expand Down Expand Up @@ -113,7 +113,9 @@ def caller(
numeric_only=True,
**kwargs,
):
if not isinstance(by, (type(query_compiler), str)):
if not (isinstance(by, (type(query_compiler)) or hashable(by))) or isinstance(
by, pandas.Grouper
):
by = try_cast_to_pandas(by, squeeze=True)
default_func = (
(lambda grp: grp.agg(map_func))
Expand Down
12 changes: 6 additions & 6 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def groupby(

if callable(by):
by = self.index.map(by)
elif isinstance(by, str):
elif hashable(by) and not isinstance(by, pandas.Grouper):
drop = by in self.columns
idx_name = by
if self._query_compiler.has_multiindex(
Expand All @@ -374,7 +374,7 @@ def groupby(
# In this case we pass the string value of the name through to the
# partitions. This is more efficient than broadcasting the values.
pass
else:
elif level is None:
by = self.__getitem__(by)._query_compiler
elif isinstance(by, Series):
drop = by._parent is self
Expand All @@ -384,7 +384,7 @@ def groupby(
# fastpath for multi column groupby
if axis == 0 and all(
(
(isinstance(o, str) and (o in self))
(hashable(o) and (o in self))
or isinstance(o, Series)
or (is_list_like(o) and len(o) == len(self.axes[axis]))
)
Expand All @@ -395,7 +395,7 @@ def groupby(
internal_by, external_by = [], []

for current_by in by:
if isinstance(current_by, str):
if hashable(current_by):
internal_by.append(current_by)
elif isinstance(current_by, Series):
if current_by._parent is self:
Expand All @@ -414,7 +414,7 @@ def groupby(
else:
mismatch = len(by) != len(self.axes[axis])
if mismatch and all(
isinstance(obj, str)
hashable(obj)
and (
obj in self or obj in self._query_compiler.get_index_names(axis)
)
Expand All @@ -424,7 +424,7 @@ def groupby(
# we default to pandas in this case.
pass
elif mismatch and any(
isinstance(obj, str) and obj not in self.columns for obj in by
hashable(obj) and obj not in self.columns for obj in by
):
names = [o.name if isinstance(o, Series) else o for o in by]
raise KeyError(next(x for x in names if x not in self))
Expand Down
20 changes: 16 additions & 4 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from collections.abc import Iterable

from modin.error_message import ErrorMessage
from modin.utils import _inherit_docstrings, try_cast_to_pandas, wrap_udf_function
from modin.utils import (
_inherit_docstrings,
try_cast_to_pandas,
wrap_udf_function,
hashable,
)
from modin.backends.base.query_compiler import BaseQueryCompiler
from modin.config import IsExperimental
from .series import Series
Expand Down Expand Up @@ -79,7 +84,7 @@ def __init__(
not isinstance(by, type(self._query_compiler))
and axis == 0
and all(
(isinstance(obj, str) and obj in self._query_compiler.columns)
(hashable(obj) and obj in self._query_compiler.columns)
or isinstance(obj, type(self._query_compiler))
or is_list_like(obj)
for obj in self._by
Expand Down Expand Up @@ -324,7 +329,7 @@ def __getitem__(self, key):
if (
self._is_multi_by
and isinstance(self._by, list)
and not all(isinstance(o, str) for o in self._by)
and not all(hashable(o) and o in self._df for o in self._by)
):
raise NotImplementedError(
"Column lookups on GroupBy with arbitrary Series in by"
Expand Down Expand Up @@ -809,7 +814,14 @@ def _index_grouped(self):
# aware.
ErrorMessage.catch_bugs_and_request_email(self._axis == 1)
ErrorMessage.default_to_pandas("Groupby with multiple columns")
if isinstance(by, list) and all(isinstance(o, str) for o in by):
if isinstance(by, list) and all(
hashable(o)
and (
o in self._df
or o in self._df._query_compiler.get_index_names(self._axis)
)
for o in by
):
pandas_df = self._df._query_compiler.getitem_column_array(
by
).to_pandas()
Expand Down
31 changes: 31 additions & 0 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,3 +1494,34 @@ def test_multi_column_groupby_different_partitions(
by, as_index=as_index
)
eval_general(md_grp, pd_grp, func_to_apply)


@pytest.mark.parametrize(
"by",
[
0,
1.5,
"str",
pandas.Timestamp("2020-02-02"),
[None],
[0, "str"],
[None, 0],
[pandas.Timestamp("2020-02-02"), 1.5],
],
)
@pytest.mark.parametrize("as_index", [True, False])
def test_not_str_by(by, as_index):
data = {f"col{i}": np.arange(5) for i in range(5)}
columns = pandas.Index([0, 1.5, "str", pandas.Timestamp("2020-02-02"), None])

md_df, pd_df = create_test_dfs(data, columns=columns)
md_grp, pd_grp = md_df.groupby(by, as_index=as_index), pd_df.groupby(
by, as_index=as_index
)

modin_groupby_equals_pandas(md_grp, pd_grp)
df_equals(md_grp.sum(), pd_grp.sum())
df_equals(md_grp.size(), pd_grp.size())
df_equals(md_grp.agg(lambda df: df.mean()), pd_grp.agg(lambda df: df.mean()))
df_equals(md_grp.dtypes, pd_grp.dtypes)
df_equals(md_grp.first(), pd_grp.first())

0 comments on commit 9d9dc29

Please sign in to comment.