From 029ec0b433722b7141fff6b0d1abf847086d622e Mon Sep 17 00:00:00 2001 From: Ben Hauser Date: Mon, 5 Oct 2020 14:52:33 +0300 Subject: [PATCH 1/3] feat: add type information to target value in for loop --- vyper/ast/nodes.py | 5 ++++- vyper/ast/nodes.pyi | 1 + vyper/context/validation/local.py | 6 +++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 69546da514..a29dfea374 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -786,7 +786,10 @@ class NameConstant(Constant): class Name(VyperNode): - __slots__ = ("id",) + __slots__ = ( + "id", + "_type", + ) class Expr(VyperNode): diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 9a23fb7122..f53a02edaa 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -121,6 +121,7 @@ class NameConstant(Constant): ... class Name(VyperNode): id: str = ... + _type: str = ... class Expr(VyperNode): value: VyperNode = ... diff --git a/vyper/context/validation/local.py b/vyper/context/validation/local.py index e7859a2f31..b49df21f76 100644 --- a/vyper/context/validation/local.py +++ b/vyper/context/validation/local.py @@ -35,6 +35,7 @@ FunctionDeclarationException, ImmutableViolation, InvalidLiteral, + InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -369,8 +370,11 @@ def visit_For(self, node): try: for n in node.body: self.visit(n) + # attach type information to allow non `int128` types in `vyper.parser.stmt` + # this is a temporary solution until `vyper.parser` has been refactored + node.target._type = type_._id return - except TypeMismatch as exc: + except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) if len(set(str(i) for i in for_loop_exceptions)) == 1: From ca1f810971c8a2a382a081327c0c5ea2b8359613 Mon Sep 17 00:00:00 2001 From: Ben Hauser Date: Mon, 5 Oct 2020 14:53:39 +0300 Subject: [PATCH 2/3] fix: use iter type when generating LLL from for loop --- vyper/parser/stmt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vyper/parser/stmt.py b/vyper/parser/stmt.py index e9f5ddc786..ed45a6e560 100644 --- a/vyper/parser/stmt.py +++ b/vyper/parser/stmt.py @@ -231,6 +231,10 @@ def parse_For(self): if not 0 < len(self.stmt.iter.args) < 3: return + # attempt to use the type specified by type checking, fall back to `int128` + # this is a stopgap solution to allow uint256 - it will be properly solved + # once we refactor `vyper.parser` + iter_typ = self.stmt.target.get("_type") or "int128" block_scope_id = id(self.stmt) with self.context.make_blockscope(block_scope_id): # Get arg0 @@ -240,15 +244,15 @@ def parse_For(self): # Type 1 for, e.g. for i in range(10): ... if num_of_args == 1: arg0_val = self._get_range_const_value(arg0) - start = LLLnode.from_list(0, typ="int128", pos=getpos(self.stmt)) + start = LLLnode.from_list(0, typ=iter_typ, pos=getpos(self.stmt)) rounds = arg0_val # Type 2 for, e.g. for i in range(100, 110): ... elif self._check_valid_range_constant(self.stmt.iter.args[1], raise_exception=False)[0]: arg0_val = self._get_range_const_value(arg0) arg1_val = self._get_range_const_value(self.stmt.iter.args[1]) - start = LLLnode.from_list(arg0_val, typ="int128", pos=getpos(self.stmt)) - rounds = LLLnode.from_list(arg1_val - arg0_val, typ="int128", pos=getpos(self.stmt)) + start = LLLnode.from_list(arg0_val, typ=iter_typ, pos=getpos(self.stmt)) + rounds = LLLnode.from_list(arg1_val - arg0_val, typ=iter_typ, pos=getpos(self.stmt)) # Type 3 for, e.g. for i in range(x, x + 10): ... else: @@ -261,7 +265,7 @@ def parse_For(self): return varname = self.stmt.target.id - pos = self.context.new_variable(varname, BaseType("int128"), pos=getpos(self.stmt)) + pos = self.context.new_variable(varname, BaseType(iter_typ), pos=getpos(self.stmt)) self.context.forvars[varname] = True lll_node = LLLnode.from_list( ["repeat", pos, start, rounds, parse_body(self.stmt.body, self.context)], From 7f3815b4bad271a6581a6a2b6021720120e50c91 Mon Sep 17 00:00:00 2001 From: Ben Hauser Date: Fri, 9 Oct 2020 20:12:10 +0300 Subject: [PATCH 3/3] test: add test cases for uint256 range iteration --- .../features/iteration/test_for_in_list.py | 13 +++ .../features/iteration/test_repeater.py | 82 ++++++++++--------- 2 files changed, 58 insertions(+), 37 deletions(-) diff --git a/tests/parser/features/iteration/test_for_in_list.py b/tests/parser/features/iteration/test_for_in_list.py index 066c188cf4..4bd11d02a2 100644 --- a/tests/parser/features/iteration/test_for_in_list.py +++ b/tests/parser/features/iteration/test_for_in_list.py @@ -506,6 +506,19 @@ def test_for() -> int128: """, TypeMismatch, ), + ( + """ +@external +def test_for() -> int128: + a: int128 = 0 + b: uint256 = 0 + for i in range(5): + a = i + b = i + return a + """, + TypeMismatch, + ), ] diff --git a/tests/parser/features/iteration/test_repeater.py b/tests/parser/features/iteration/test_repeater.py index 9cc4bf7697..192633af9c 100644 --- a/tests/parser/features/iteration/test_repeater.py +++ b/tests/parser/features/iteration/test_repeater.py @@ -1,3 +1,6 @@ +import pytest + + def test_basic_repeater(get_contract_with_gas_estimation): basic_repeater = """ @external @@ -47,30 +50,32 @@ def repeat() -> int128: assert c.repeat() == 666666 -def test_offset_repeater(get_contract_with_gas_estimation): - offset_repeater = """ +@pytest.mark.parametrize("typ", ["int128", "uint256"]) +def test_offset_repeater(get_contract_with_gas_estimation, typ): + offset_repeater = f""" @external -def sum() -> int128: - out: int128 = 0 +def sum() -> {typ}: + out: {typ} = 0 for i in range(80, 121): out = out + i - return(out) + return out """ c = get_contract_with_gas_estimation(offset_repeater) assert c.sum() == 4100 -def test_offset_repeater_2(get_contract_with_gas_estimation): - offset_repeater_2 = """ +@pytest.mark.parametrize("typ", ["int128", "uint256"]) +def test_offset_repeater_2(get_contract_with_gas_estimation, typ): + offset_repeater_2 = f""" @external -def sum(frm: int128, to: int128) -> int128: - out: int128 = 0 +def sum(frm: {typ}, to: {typ}) -> {typ}: + out: {typ} = 0 for i in range(frm, frm + 101): if i == to: break out = out + i - return(out) + return out """ c = get_contract_with_gas_estimation(offset_repeater_2) @@ -95,61 +100,64 @@ def foo() -> bool: assert c.foo() is True -def test_return_inside_repeater(get_contract): - code = """ +@pytest.mark.parametrize("typ", ["int128", "uint256"]) +def test_return_inside_repeater(get_contract, typ): + code = f""" @internal -def _final(a: int128) -> int128: +def _final(a: {typ}) -> {typ}: for i in range(10): if i > a: return i - return -42 + return 31337 @internal -def _middle(a: int128) -> int128: - b: int128 = self._final(a) +def _middle(a: {typ}) -> {typ}: + b: {typ} = self._final(a) return b @external -def foo(a: int128) -> int128: - b: int128 = self._middle(a) +def foo(a: {typ}) -> {typ}: + b: {typ} = self._middle(a) return b """ c = get_contract(code) assert c.foo(6) == 7 - assert c.foo(100) == -42 + assert c.foo(100) == 31337 -def test_return_inside_nested_repeater(get_contract): - code = """ +@pytest.mark.parametrize("typ", ["int128", "uint256"]) +def test_return_inside_nested_repeater(get_contract, typ): + code = f""" @internal -def _final(a: int128) -> int128: +def _final(a: {typ}) -> {typ}: for i in range(10): for x in range(10): if i + x > a: return i + x - return -42 + return 31337 @internal -def _middle(a: int128) -> int128: - b: int128 = self._final(a) +def _middle(a: {typ}) -> {typ}: + b: {typ} = self._final(a) return b @external -def foo(a: int128) -> int128: - b: int128 = self._middle(a) +def foo(a: {typ}) -> {typ}: + b: {typ} = self._middle(a) return b """ c = get_contract(code) assert c.foo(14) == 15 - assert c.foo(100) == -42 + assert c.foo(100) == 31337 -def test_breaks_and_returns_inside_nested_repeater(get_contract): - code = """ +@pytest.mark.parametrize("typ", ["int128", "uint256"]) +def test_breaks_and_returns_inside_nested_repeater(get_contract, typ): + code = f""" @internal -def _final(a: int128) -> int128: +def _final(a: {typ}) -> {typ}: for i in range(10): for x in range(10): if a < 2: @@ -159,20 +167,20 @@ def _final(a: int128) -> int128: break return 31337 - return -42 + return 666 @internal -def _middle(a: int128) -> int128: - b: int128 = self._final(a) +def _middle(a: {typ}) -> {typ}: + b: {typ} = self._final(a) return b @external -def foo(a: int128) -> int128: - b: int128 = self._middle(a) +def foo(a: {typ}) -> {typ}: + b: {typ} = self._middle(a) return b """ c = get_contract(code) assert c.foo(100) == 6 - assert c.foo(1) == -42 + assert c.foo(1) == 666 assert c.foo(0) == 31337