Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise IterationError on StopIteration #473

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion toolz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Aliases
comp = compose

from . import curried, sandbox
from . import curried, exceptions, sandbox

functoolz._sigs.create_signature_registry()

Expand Down
6 changes: 6 additions & 0 deletions toolz/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

__all__ = ('IterationError',)


class IterationError(RuntimeError):
pass
36 changes: 27 additions & 9 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from random import Random
from collections.abc import Sequence
from toolz.utils import no_default
from toolz.exceptions import IterationError


__all__ = ('remove', 'accumulate', 'groupby', 'merge_sorted', 'interleave',
Expand Down Expand Up @@ -373,7 +374,9 @@ def first(seq):
>>> first('ABC')
'A'
"""
return next(iter(seq))
for rv in seq:
return rv
raise IterationError("Received empty sequence")


def second(seq):
Expand All @@ -383,8 +386,14 @@ def second(seq):
'B'
"""
seq = iter(seq)
next(seq)
return next(seq)
for item in seq:
break
else:
raise IterationError("Received empty sequence")
for item in seq:
return item
else:
raise IterationError("Length of sequence is < 2")


def nth(n, seq):
Expand All @@ -396,7 +405,9 @@ def nth(n, seq):
if isinstance(seq, (tuple, list, Sequence)):
return seq[n]
else:
return next(itertools.islice(seq, n, None))
for rv in itertools.islice(seq, n, None):
return rv
raise IterationError("Length of seq is < %d" % n)


def last(seq):
Expand Down Expand Up @@ -531,8 +542,9 @@ def interpose(el, seq):
[1, 'a', 2, 'a', 3]
"""
inposed = concat(zip(itertools.repeat(el), seq))
next(inposed)
return inposed
for _ in inposed:
return inposed
raise IterationError("Received empty sequence")


def frequencies(seq):
Expand Down Expand Up @@ -722,13 +734,16 @@ def partition_all(n, seq):
"""
args = [iter(seq)] * n
it = zip_longest(*args, fillvalue=no_pad)

try:
prev = next(it)
except StopIteration:
return

for item in it:
yield prev
prev = item

if prev[-1] is no_pad:
try:
# If seq defines __len__, then
Expand Down Expand Up @@ -997,8 +1012,11 @@ def peek(seq):
[0, 1, 2, 3, 4]
"""
iterator = iter(seq)
item = next(iterator)
return item, itertools.chain((item,), iterator)
for peeked in iterator:
break
else:
raise IterationError("Received empty sequence")
return peeked, itertools.chain((peeked,), iterator)


def peekn(n, seq):
Expand All @@ -1016,7 +1034,7 @@ def peekn(n, seq):
"""
iterator = iter(seq)
peeked = tuple(take(n, iterator))
return peeked, itertools.chain(iter(peeked), iterator)
return peeked, itertools.chain(peeked, iterator)


def random_sample(prob, seq, random_state=None):
Expand Down
13 changes: 9 additions & 4 deletions toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, peekn, random_sample)
from operator import add, mul

from toolz.exceptions import IterationError

from operator import add, mul

# is comparison will fail between this and no_default
no_default2 = loads(dumps('__no__default__'))
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_nth():
assert nth(2, iter('ABCDE')) == 'C'
assert nth(1, (3, 2, 1)) == 2
assert nth(0, {'foo': 'bar'}) == 'foo'
assert raises(StopIteration, lambda: nth(10, {10: 'foo'}))
assert raises(IterationError, lambda: nth(10, {10: 'foo'}))
assert nth(-2, 'ABCDE') == 'D'
assert raises(ValueError, lambda: nth(-2, iter('ABCDE')))

Expand All @@ -136,12 +138,15 @@ def test_first():
assert first('ABCDE') == 'A'
assert first((3, 2, 1)) == 3
assert isinstance(first({0: 'zero', 1: 'one'}), int)
assert raises(IterationError, lambda: first([]))


def test_second():
assert second('ABCDE') == 'B'
assert second((3, 2, 1)) == 2
assert isinstance(second({0: 'zero', 1: 'one'}), int)
assert raises(IterationError, lambda: second([1]))
assert raises(IterationError, lambda: second([]))


def test_last():
Expand Down Expand Up @@ -228,6 +233,7 @@ def test_interpose():
assert "tXaXrXzXaXn" == "".join(interpose("X", "tarzan"))
assert list(interpose(0, itertools.repeat(1, 4))) == [1, 0, 1, 0, 1, 0, 1]
assert list(interpose('.', ['a', 'b', 'c'])) == ['a', '.', 'b', '.', 'c']
assert raises(IterationError, lambda: interpose('a', []))


def test_frequencies():
Expand Down Expand Up @@ -510,8 +516,7 @@ def test_peek():
element, blist = peek(alist)
assert element == alist[0]
assert list(blist) == alist

assert raises(StopIteration, lambda: peek([]))
assert raises(IterationError, lambda: peek([]))


def test_peekn():
Expand Down