Skip to content

Commit

Permalink
Simplify DataFrame.__getitem__ (#771)
Browse files Browse the repository at this point in the history
* simplify DataFrame.__getitem__

* restrict to Iterable with Hashable/Scalar elements

* unrelated: widen type of DataFrame.__iter__, can be any column type

* Scalar | Hashable -> Hashable
  • Loading branch information
twoertwein authored Sep 1, 2023
1 parent ab5c643 commit 6bb1215
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 18 deletions.
22 changes: 4 additions & 18 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import (
Callable,
Generator,
Hashable,
Iterable,
Iterator,
Expand All @@ -14,7 +13,6 @@ from typing import (
Any,
ClassVar,
Literal,
TypeVar,
overload,
)

Expand Down Expand Up @@ -119,8 +117,6 @@ from pandas._typing import (
ValidationOptions,
WriteBuffer,
XMLParsers,
np_ndarray_bool,
np_ndarray_str,
npt,
num,
)
Expand All @@ -130,7 +126,6 @@ from pandas.plotting import PlotAccessor

_str = str
_bool = bool
_ScalarOrTupleT = TypeVar("_ScalarOrTupleT", bound=Scalar | tuple[Hashable, ...])

class _iLocIndexerFrame(_iLocIndexer):
@overload
Expand Down Expand Up @@ -553,20 +548,11 @@ class DataFrame(NDFrame, OpsMixin):
def T(self) -> DataFrame: ...
def __getattr__(self, name: str) -> Series: ...
@overload
def __getitem__( # type: ignore[misc]
self,
key: Series
| DataFrame
| Index
| np_ndarray_str
| np_ndarray_bool
| list[_ScalarOrTupleT]
| Generator[_ScalarOrTupleT, None, None],
) -> DataFrame: ...
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[misc]
@overload
def __getitem__(self, key: slice) -> DataFrame: ...
def __getitem__(self, key: Iterable[Hashable] | slice) -> DataFrame: ...
@overload
def __getitem__(self, key: Scalar | Hashable) -> Series: ...
def __getitem__(self, key: Hashable) -> Series: ...
def isetitem(
self, loc: int | Sequence[int], value: Scalar | ArrayLike | list[Any]
) -> None: ...
Expand Down Expand Up @@ -1477,7 +1463,7 @@ class DataFrame(NDFrame, OpsMixin):
Name: _str
#
# dunder methods
def __iter__(self) -> Iterator[float | _str]: ...
def __iter__(self) -> Iterator[Hashable]: ...
# properties
@property
def at(self): ... # Not sure what to do with this yet; look at source
Expand Down
7 changes: 7 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2811,3 +2811,10 @@ def test_groupby_fillna_inplace() -> None:
def test_getitem_generator() -> None:
# GH 685
check(assert_type(DF[(f"col{i+1}" for i in range(2))], pd.DataFrame), pd.DataFrame)


def test_getitem_dict_keys() -> None:
# GH 770
some_columns = {"a": [1], "b": [2]}
df = pd.DataFrame.from_dict(some_columns)
check(assert_type(df[some_columns.keys()], pd.DataFrame), pd.DataFrame)

0 comments on commit 6bb1215

Please sign in to comment.