Skip to content

Commit

Permalink
Merge pull request #2180 from iamdefinitelyahuman/fix-for-iterator-type
Browse files Browse the repository at this point in the history
Allow uint256 as iterator type in range-based for loop
  • Loading branch information
fubuloubu authored Oct 9, 2020
2 parents 293f83d + 7f3815b commit ce85d7d
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 43 deletions.
13 changes: 13 additions & 0 deletions tests/parser/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,19 @@ def test_for() -> int128:
""",
TypeMismatch,
),
(
"""
@external
def test_for() -> int128:
a: int128 = 0
b: uint256 = 0
for i in range(5):
a = i
b = i
return a
""",
TypeMismatch,
),
]


Expand Down
82 changes: 45 additions & 37 deletions tests/parser/features/iteration/test_repeater.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_basic_repeater(get_contract_with_gas_estimation):
basic_repeater = """
@external
Expand Down Expand Up @@ -47,30 +50,32 @@ def repeat() -> int128:
assert c.repeat() == 666666


def test_offset_repeater(get_contract_with_gas_estimation):
offset_repeater = """
@pytest.mark.parametrize("typ", ["int128", "uint256"])
def test_offset_repeater(get_contract_with_gas_estimation, typ):
offset_repeater = f"""
@external
def sum() -> int128:
out: int128 = 0
def sum() -> {typ}:
out: {typ} = 0
for i in range(80, 121):
out = out + i
return(out)
return out
"""

c = get_contract_with_gas_estimation(offset_repeater)
assert c.sum() == 4100


def test_offset_repeater_2(get_contract_with_gas_estimation):
offset_repeater_2 = """
@pytest.mark.parametrize("typ", ["int128", "uint256"])
def test_offset_repeater_2(get_contract_with_gas_estimation, typ):
offset_repeater_2 = f"""
@external
def sum(frm: int128, to: int128) -> int128:
out: int128 = 0
def sum(frm: {typ}, to: {typ}) -> {typ}:
out: {typ} = 0
for i in range(frm, frm + 101):
if i == to:
break
out = out + i
return(out)
return out
"""

c = get_contract_with_gas_estimation(offset_repeater_2)
Expand All @@ -95,61 +100,64 @@ def foo() -> bool:
assert c.foo() is True


def test_return_inside_repeater(get_contract):
code = """
@pytest.mark.parametrize("typ", ["int128", "uint256"])
def test_return_inside_repeater(get_contract, typ):
code = f"""
@internal
def _final(a: int128) -> int128:
def _final(a: {typ}) -> {typ}:
for i in range(10):
if i > a:
return i
return -42
return 31337
@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
def _middle(a: {typ}) -> {typ}:
b: {typ} = self._final(a)
return b
@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
def foo(a: {typ}) -> {typ}:
b: {typ} = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(6) == 7
assert c.foo(100) == -42
assert c.foo(100) == 31337


def test_return_inside_nested_repeater(get_contract):
code = """
@pytest.mark.parametrize("typ", ["int128", "uint256"])
def test_return_inside_nested_repeater(get_contract, typ):
code = f"""
@internal
def _final(a: int128) -> int128:
def _final(a: {typ}) -> {typ}:
for i in range(10):
for x in range(10):
if i + x > a:
return i + x
return -42
return 31337
@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
def _middle(a: {typ}) -> {typ}:
b: {typ} = self._final(a)
return b
@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
def foo(a: {typ}) -> {typ}:
b: {typ} = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(14) == 15
assert c.foo(100) == -42
assert c.foo(100) == 31337


def test_breaks_and_returns_inside_nested_repeater(get_contract):
code = """
@pytest.mark.parametrize("typ", ["int128", "uint256"])
def test_breaks_and_returns_inside_nested_repeater(get_contract, typ):
code = f"""
@internal
def _final(a: int128) -> int128:
def _final(a: {typ}) -> {typ}:
for i in range(10):
for x in range(10):
if a < 2:
Expand All @@ -159,20 +167,20 @@ def _final(a: int128) -> int128:
break
return 31337
return -42
return 666
@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
def _middle(a: {typ}) -> {typ}:
b: {typ} = self._final(a)
return b
@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
def foo(a: {typ}) -> {typ}:
b: {typ} = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(100) == 6
assert c.foo(1) == -42
assert c.foo(1) == 666
assert c.foo(0) == 31337
5 changes: 4 additions & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,10 @@ class NameConstant(Constant):


class Name(VyperNode):
__slots__ = ("id",)
__slots__ = (
"id",
"_type",
)


class Expr(VyperNode):
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class NameConstant(Constant): ...

class Name(VyperNode):
id: str = ...
_type: str = ...

class Expr(VyperNode):
value: VyperNode = ...
Expand Down
6 changes: 5 additions & 1 deletion vyper/context/validation/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FunctionDeclarationException,
ImmutableViolation,
InvalidLiteral,
InvalidOperation,
InvalidType,
IteratorException,
NonPayableViolation,
Expand Down Expand Up @@ -369,8 +370,11 @@ def visit_For(self, node):
try:
for n in node.body:
self.visit(n)
# attach type information to allow non `int128` types in `vyper.parser.stmt`
# this is a temporary solution until `vyper.parser` has been refactored
node.target._type = type_._id
return
except TypeMismatch as exc:
except (TypeMismatch, InvalidOperation) as exc:
for_loop_exceptions.append(exc)

if len(set(str(i) for i in for_loop_exceptions)) == 1:
Expand Down
12 changes: 8 additions & 4 deletions vyper/parser/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def parse_For(self):
if not 0 < len(self.stmt.iter.args) < 3:
return

# attempt to use the type specified by type checking, fall back to `int128`
# this is a stopgap solution to allow uint256 - it will be properly solved
# once we refactor `vyper.parser`
iter_typ = self.stmt.target.get("_type") or "int128"
block_scope_id = id(self.stmt)
with self.context.make_blockscope(block_scope_id):
# Get arg0
Expand All @@ -240,15 +244,15 @@ def parse_For(self):
# Type 1 for, e.g. for i in range(10): ...
if num_of_args == 1:
arg0_val = self._get_range_const_value(arg0)
start = LLLnode.from_list(0, typ="int128", pos=getpos(self.stmt))
start = LLLnode.from_list(0, typ=iter_typ, pos=getpos(self.stmt))
rounds = arg0_val

# Type 2 for, e.g. for i in range(100, 110): ...
elif self._check_valid_range_constant(self.stmt.iter.args[1], raise_exception=False)[0]:
arg0_val = self._get_range_const_value(arg0)
arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
start = LLLnode.from_list(arg0_val, typ="int128", pos=getpos(self.stmt))
rounds = LLLnode.from_list(arg1_val - arg0_val, typ="int128", pos=getpos(self.stmt))
start = LLLnode.from_list(arg0_val, typ=iter_typ, pos=getpos(self.stmt))
rounds = LLLnode.from_list(arg1_val - arg0_val, typ=iter_typ, pos=getpos(self.stmt))

# Type 3 for, e.g. for i in range(x, x + 10): ...
else:
Expand All @@ -261,7 +265,7 @@ def parse_For(self):
return

varname = self.stmt.target.id
pos = self.context.new_variable(varname, BaseType("int128"), pos=getpos(self.stmt))
pos = self.context.new_variable(varname, BaseType(iter_typ), pos=getpos(self.stmt))
self.context.forvars[varname] = True
lll_node = LLLnode.from_list(
["repeat", pos, start, rounds, parse_body(self.stmt.body, self.context)],
Expand Down

0 comments on commit ce85d7d

Please sign in to comment.