Skip to content

Commit

Permalink
Fix groupby.get_group with length-1 tuple with list-like grouper (#17216
Browse files Browse the repository at this point in the history
)

closes #17187

Adds similar logic as implemented in pandas: https://github.com/pandas-dev/pandas/blob/main/pandas/core/groupby/groupby.py#L751-L758

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17216
  • Loading branch information
mroeschke authored Oct 31, 2024
1 parent 02a50e8 commit a83debb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,11 @@ def get_group(self, name, obj=None):
"instead of ``gb.get_group(name, obj=df)``.",
FutureWarning,
)
if is_list_like(self._by):
if isinstance(name, tuple) and len(name) == 1:
name = name[0]
else:
raise KeyError(name)
return obj.iloc[self.indices[name]]

@_performance_tracking
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4059,3 +4059,19 @@ def test_ndim():
pgb = pser.groupby([0, 0, 1])
ggb = gser.groupby(cudf.Series([0, 0, 1]))
assert pgb.ndim == ggb.ndim


@pytest.mark.skipif(
not PANDAS_GE_220, reason="pandas behavior applicable in >=2.2"
)
def test_get_group_list_like():
df = cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
result = df.groupby(["a"]).get_group((1,))
expected = df.to_pandas().groupby(["a"]).get_group((1,))
assert_eq(result, expected)

with pytest.raises(KeyError):
df.groupby(["a"]).get_group((1, 2))

with pytest.raises(KeyError):
df.groupby(["a"]).get_group([1])

0 comments on commit a83debb

Please sign in to comment.