Skip to content

Commit

Permalink
dev: migrate typing info to inline
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Mar 9, 2024
1 parent 76efe89 commit f6a2d80
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 249 deletions.
106 changes: 87 additions & 19 deletions iterpy/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict
from functools import reduce
from itertools import islice
from typing import TYPE_CHECKING, Any, Generator, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generator, Generic, Sequence, TypeVar, overload

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator
Expand All @@ -32,6 +32,11 @@ def __iter__(self) -> Iter[T]:
def __next__(self) -> T:
return next(self._iter)

@overload
def __getitem__(self, index: int) -> T: ...
@overload
def __getitem__(self, index: slice) -> Iter[T]: ...

def __getitem__(self, index: int | slice) -> T | Iter[T]:
if isinstance(index, int) and index >= 0:
try:
Expand Down Expand Up @@ -98,24 +103,6 @@ def groupby(self, func: Callable[[T], str]) -> Iter[tuple[str, list[T]]]:
tuples = list(groups_with_values.items())
return Iter(tuples)

def flatten(self) -> Iter[T]:
depth = 1

def walk(node: Any, level: int) -> Generator[T, None, None]:
if (level > depth) or isinstance(node, str):
yield node # type: ignore
return
try:
tree = iter(node)
except TypeError:
yield node
return
else:
for child in tree:
yield from walk(child, level + 1)

return Iter(walk(self, level=0)) # type: ignore

def take(self, n: int = 1) -> Iter[T]:
return Iter(islice(self._iter, n))

Expand Down Expand Up @@ -154,3 +141,84 @@ def clone(self) -> Iter[T]:

def zip(self, other: Iter[S]) -> Iter[tuple[T, S]]:
return Iter(zip(self, other))

############################################################
# Auto-generated overloads for flatten #
# Code for generating the following is in _generate_pyi.py #
############################################################
# Overloads are technically incompatible, because they use generic S instead of T. However, this is required for the flattening logic to work.

@overload
def flatten(self: Iter[Iterable[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[Iterable[S] | S]) -> Iter[S]: ...

# Iterator[S] # noqa: ERA001
@overload
def flatten(self: Iter[Iterator[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[Iterator[S] | S]) -> Iter[S]: ...

# tuple[S, ...] # noqa: ERA001
@overload
def flatten(self: Iter[tuple[S, ...]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[tuple[S, ...] | S]) -> Iter[S]: ...

# Sequence[S] # noqa: ERA001
@overload
def flatten(self: Iter[Sequence[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[Sequence[S] | S]) -> Iter[S]: ...

# list[S] # noqa: ERA001
@overload
def flatten(self: Iter[list[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[list[S] | S]) -> Iter[S]: ...

# set[S] # noqa: ERA001
@overload
def flatten(self: Iter[set[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[set[S] | S]) -> Iter[S]: ...

# frozenset[S] # noqa: ERA001
@overload
def flatten(self: Iter[frozenset[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[frozenset[S] | S]) -> Iter[S]: ...

# Iter[S] # noqa: ERA001
@overload
def flatten(self: Iter[Iter[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[Iter[S] | S]) -> Iter[S]: ...

# str
@overload
def flatten(self: Iter[str]) -> Iter[str]: ...
@overload
def flatten(self: Iter[str | S]) -> Iter[S]: ...

# Generic
@overload
def flatten(self: Iter[S]) -> Iter[S]: ...

def flatten(self) -> Iter[T]: # type: ignore -
depth = 1

def walk(node: Any, level: int) -> Generator[T, None, None]:
if (level > depth) or isinstance(node, str):
yield node # type: ignore
return
try:
tree = iter(node)
except TypeError:
yield node
return
else:
for child in tree:
yield from walk(child, level + 1)

return Iter(walk(self, level=0)) # type: ignore
100 changes: 0 additions & 100 deletions iterpy/iter.pyi

This file was deleted.

3 changes: 2 additions & 1 deletion iterpy/test_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def test_flatten_list(self):

def test_flatten_iterator(self):
test_input: Sequence[Sequence[int]] = [[1, 2], [3, 4]]
result: Iter[int] = Iter(test_input).flatten()
iterator = Iter(test_input)
result: Iter[int] = iterator.flatten()
assert result.to_list() == [1, 2, 3, 4]

def test_flatten_iter_iter(self):
Expand Down
16 changes: 3 additions & 13 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ cookiecutter==2.5.0
coverage==7.4.1
# via pytest-cov
cruft==2.15.0
# via iterpy
diff-cover==8.0.3
# via iterpy
exceptiongroup==1.2.0
# via pytest
execnet==2.0.2
Expand All @@ -53,7 +51,6 @@ jinja2==3.1.3
# via cookiecutter
# via diff-cover
lumberman==0.45.0
# via iterpy
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
Expand All @@ -73,22 +70,16 @@ pycparser==2.21
pygments==2.17.2
# via diff-cover
# via rich
pyright==1.1.350
# via iterpy
pytest==8.0.1
# via iterpy
pyright==1.1.352
pytest==8.1.0
# via pytest-codspeed
# via pytest-cov
# via pytest-sugar
# via pytest-xdist
pytest-codspeed==2.2.0
# via iterpy
pytest-cov==4.1.0
# via iterpy
pytest-sugar==1.0.0
# via iterpy
pytest-xdist==3.5.0
# via iterpy
python-dateutil==2.8.2
# via arrow
python-slugify==8.0.4
Expand All @@ -100,8 +91,7 @@ requests==2.31.0
rich==13.7.0
# via cookiecutter
# via lumberman
ruff==0.2.2
# via iterpy
ruff==0.3.0
setuptools==69.1.0
# via nodeenv
six==1.16.0
Expand Down
Loading

0 comments on commit f6a2d80

Please sign in to comment.