From 6bb12157a810cc26e034e012bff6a0ec2da191d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Fri, 1 Sep 2023 10:19:01 -0400 Subject: [PATCH] Simplify DataFrame.__getitem__ (#771) * simplify DataFrame.__getitem__ * restrict to Iterable with Hashable/Scalar elements * unrelated: widen type of DataFrame.__iter__, can be any column type * Scalar | Hashable -> Hashable --- pandas-stubs/core/frame.pyi | 22 ++++------------------ tests/test_frame.py | 7 +++++++ 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index eff25f3c..d2394cdf 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1,6 +1,5 @@ from collections.abc import ( Callable, - Generator, Hashable, Iterable, Iterator, @@ -14,7 +13,6 @@ from typing import ( Any, ClassVar, Literal, - TypeVar, overload, ) @@ -119,8 +117,6 @@ from pandas._typing import ( ValidationOptions, WriteBuffer, XMLParsers, - np_ndarray_bool, - np_ndarray_str, npt, num, ) @@ -130,7 +126,6 @@ from pandas.plotting import PlotAccessor _str = str _bool = bool -_ScalarOrTupleT = TypeVar("_ScalarOrTupleT", bound=Scalar | tuple[Hashable, ...]) class _iLocIndexerFrame(_iLocIndexer): @overload @@ -553,20 +548,11 @@ class DataFrame(NDFrame, OpsMixin): def T(self) -> DataFrame: ... def __getattr__(self, name: str) -> Series: ... @overload - def __getitem__( # type: ignore[misc] - self, - key: Series - | DataFrame - | Index - | np_ndarray_str - | np_ndarray_bool - | list[_ScalarOrTupleT] - | Generator[_ScalarOrTupleT, None, None], - ) -> DataFrame: ... + def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[misc] @overload - def __getitem__(self, key: slice) -> DataFrame: ... + def __getitem__(self, key: Iterable[Hashable] | slice) -> DataFrame: ... @overload - def __getitem__(self, key: Scalar | Hashable) -> Series: ... + def __getitem__(self, key: Hashable) -> Series: ... def isetitem( self, loc: int | Sequence[int], value: Scalar | ArrayLike | list[Any] ) -> None: ... @@ -1477,7 +1463,7 @@ class DataFrame(NDFrame, OpsMixin): Name: _str # # dunder methods - def __iter__(self) -> Iterator[float | _str]: ... + def __iter__(self) -> Iterator[Hashable]: ... # properties @property def at(self): ... # Not sure what to do with this yet; look at source diff --git a/tests/test_frame.py b/tests/test_frame.py index 383f4ecc..e1e1f7e2 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2811,3 +2811,10 @@ def test_groupby_fillna_inplace() -> None: def test_getitem_generator() -> None: # GH 685 check(assert_type(DF[(f"col{i+1}" for i in range(2))], pd.DataFrame), pd.DataFrame) + + +def test_getitem_dict_keys() -> None: + # GH 770 + some_columns = {"a": [1], "b": [2]} + df = pd.DataFrame.from_dict(some_columns) + check(assert_type(df[some_columns.keys()], pd.DataFrame), pd.DataFrame)