From b9244e89e8755de00789eedced130748a09f1b06 Mon Sep 17 00:00:00 2001 From: cyberthirst Date: Sat, 25 May 2024 13:52:17 +0200 Subject: [PATCH] fix[codegen]: fix double eval of start in range expr (#4033) cache the start argument to the stack, add double eval tests --------- Co-authored-by: Charles Cooper --- .../features/iteration/test_for_range.py | 34 ++++++++++ vyper/codegen/stmt.py | 63 +++++++++---------- 2 files changed, 65 insertions(+), 32 deletions(-) diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index 61c6f453d6..b8cf8c2592 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -473,3 +473,37 @@ def foo() -> DynArray[int256, 10]: return with pytest.raises(StaticAssertionException): get_contract(code) + + +def test_for_range_start_double_eval(get_contract, tx_failed): + code = """ +@external +def foo() -> (uint256, DynArray[uint256, 3]): + x:DynArray[uint256, 3] = [3, 1] + res: DynArray[uint256, 3] = empty(DynArray[uint256, 3]) + for i:uint256 in range(x.pop(),x.pop(), bound = 3): + res.append(i) + + return len(x), res + """ + c = get_contract(code) + length, res = c.foo() + + assert (length, res) == (0, [1, 2]) + + +def test_for_range_stop_double_eval(get_contract, tx_failed): + code = """ +@external +def foo() -> (uint256, DynArray[uint256, 3]): + x:DynArray[uint256, 3] = [3, 3] + res: DynArray[uint256, 3] = empty(DynArray[uint256, 3]) + for i:uint256 in range(x.pop(), bound = 3): + res.append(i) + + return len(x), res + """ + c = get_contract(code) + length, res = c.foo() + + assert (length, res) == (1, [0, 1, 2]) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 562a9d85d7..f29c0ea42d 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -200,44 +200,43 @@ def _parse_For_range(self): # sanity check that the following `end - start` is a valid operation assert start.typ == end.typ == target_type - if "bound" in kwargs: - with end.cache_when_complex("end") as (b1, end): - # note: the check for rounds<=rounds_bound happens in asm - # generation for `repeat`. - clamped_start = clamp_le(start, end, target_type.is_signed) - rounds = b1.resolve(IRnode.from_list(["sub", end, clamped_start])) - rounds_bound = kwargs.pop("bound").int_value() - else: - rounds = end.int_value() - start.int_value() - rounds_bound = rounds + with start.cache_when_complex("start") as (b1, start): + if "bound" in kwargs: + with end.cache_when_complex("end") as (b2, end): + # note: the check for rounds<=rounds_bound happens in asm + # generation for `repeat`. + clamped_start = clamp_le(start, end, target_type.is_signed) + rounds = b2.resolve(IRnode.from_list(["sub", end, clamped_start])) + rounds_bound = kwargs.pop("bound").int_value() + else: + rounds = end.int_value() - start.int_value() + rounds_bound = rounds - assert len(kwargs) == 0 # sanity check stray keywords + assert len(kwargs) == 0 # sanity check stray keywords - if rounds_bound < 1: # pragma: nocover - raise TypeCheckFailure("unreachable: unchecked 0 bound") + if rounds_bound < 1: # pragma: nocover + raise TypeCheckFailure("unreachable: unchecked 0 bound") - varname = self.stmt.target.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) - iptr = self.context.new_variable(varname, target_type) + varname = self.stmt.target.target.id + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) + iptr = self.context.new_variable(varname, target_type) - self.context.forvars[varname] = True + self.context.forvars[varname] = True - loop_body = ["seq"] - # store the current value of i so it is accessible to userland - loop_body.append(["mstore", iptr, i]) - loop_body.append(parse_body(self.stmt.body, self.context)) - - # NOTE: codegen for `repeat` inserts an assertion that - # (gt rounds_bound rounds). note this also covers the case where - # rounds < 0. - # if we ever want to remove that, we need to manually add the assertion - # where it makes sense. - ir_node = IRnode.from_list( - ["repeat", i, start, rounds, rounds_bound, loop_body], error_msg="range() bounds check" - ) - del self.context.forvars[varname] + loop_body = ["seq"] + # store the current value of i so it is accessible to userland + loop_body.append(["mstore", iptr, i]) + loop_body.append(parse_body(self.stmt.body, self.context)) + + del self.context.forvars[varname] - return ir_node + # NOTE: codegen for `repeat` inserts an assertion that + # (gt rounds_bound rounds). note this also covers the case where + # rounds < 0. + # if we ever want to remove that, we need to manually add the assertion + # where it makes sense. + loop = ["repeat", i, start, rounds, rounds_bound, loop_body] + return b1.resolve(IRnode.from_list(loop, error_msg="range() bounds check")) def _parse_For_list(self): with self.context.range_scope():