Skip to content

Commit

Permalink
fix[codegen]: fix double eval of start in range expr (#4033)
Browse files Browse the repository at this point in the history
cache the start argument to the stack, add double eval tests

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
cyberthirst and charles-cooper authored May 25, 2024
1 parent f0afaf0 commit b9244e8
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 32 deletions.
34 changes: 34 additions & 0 deletions tests/functional/codegen/features/iteration/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,37 @@ def foo() -> DynArray[int256, 10]:
return
with pytest.raises(StaticAssertionException):
get_contract(code)


def test_for_range_start_double_eval(get_contract, tx_failed):
code = """
@external
def foo() -> (uint256, DynArray[uint256, 3]):
x:DynArray[uint256, 3] = [3, 1]
res: DynArray[uint256, 3] = empty(DynArray[uint256, 3])
for i:uint256 in range(x.pop(),x.pop(), bound = 3):
res.append(i)
return len(x), res
"""
c = get_contract(code)
length, res = c.foo()

assert (length, res) == (0, [1, 2])


def test_for_range_stop_double_eval(get_contract, tx_failed):
code = """
@external
def foo() -> (uint256, DynArray[uint256, 3]):
x:DynArray[uint256, 3] = [3, 3]
res: DynArray[uint256, 3] = empty(DynArray[uint256, 3])
for i:uint256 in range(x.pop(), bound = 3):
res.append(i)
return len(x), res
"""
c = get_contract(code)
length, res = c.foo()

assert (length, res) == (1, [0, 1, 2])
63 changes: 31 additions & 32 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,44 +200,43 @@ def _parse_For_range(self):
# sanity check that the following `end - start` is a valid operation
assert start.typ == end.typ == target_type

if "bound" in kwargs:
with end.cache_when_complex("end") as (b1, end):
# note: the check for rounds<=rounds_bound happens in asm
# generation for `repeat`.
clamped_start = clamp_le(start, end, target_type.is_signed)
rounds = b1.resolve(IRnode.from_list(["sub", end, clamped_start]))
rounds_bound = kwargs.pop("bound").int_value()
else:
rounds = end.int_value() - start.int_value()
rounds_bound = rounds
with start.cache_when_complex("start") as (b1, start):
if "bound" in kwargs:
with end.cache_when_complex("end") as (b2, end):
# note: the check for rounds<=rounds_bound happens in asm
# generation for `repeat`.
clamped_start = clamp_le(start, end, target_type.is_signed)
rounds = b2.resolve(IRnode.from_list(["sub", end, clamped_start]))
rounds_bound = kwargs.pop("bound").int_value()
else:
rounds = end.int_value() - start.int_value()
rounds_bound = rounds

assert len(kwargs) == 0 # sanity check stray keywords
assert len(kwargs) == 0 # sanity check stray keywords

if rounds_bound < 1: # pragma: nocover
raise TypeCheckFailure("unreachable: unchecked 0 bound")
if rounds_bound < 1: # pragma: nocover
raise TypeCheckFailure("unreachable: unchecked 0 bound")

varname = self.stmt.target.target.id
i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type)
iptr = self.context.new_variable(varname, target_type)
varname = self.stmt.target.target.id
i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type)
iptr = self.context.new_variable(varname, target_type)

self.context.forvars[varname] = True
self.context.forvars[varname] = True

loop_body = ["seq"]
# store the current value of i so it is accessible to userland
loop_body.append(["mstore", iptr, i])
loop_body.append(parse_body(self.stmt.body, self.context))

# NOTE: codegen for `repeat` inserts an assertion that
# (gt rounds_bound rounds). note this also covers the case where
# rounds < 0.
# if we ever want to remove that, we need to manually add the assertion
# where it makes sense.
ir_node = IRnode.from_list(
["repeat", i, start, rounds, rounds_bound, loop_body], error_msg="range() bounds check"
)
del self.context.forvars[varname]
loop_body = ["seq"]
# store the current value of i so it is accessible to userland
loop_body.append(["mstore", iptr, i])
loop_body.append(parse_body(self.stmt.body, self.context))

del self.context.forvars[varname]

return ir_node
# NOTE: codegen for `repeat` inserts an assertion that
# (gt rounds_bound rounds). note this also covers the case where
# rounds < 0.
# if we ever want to remove that, we need to manually add the assertion
# where it makes sense.
loop = ["repeat", i, start, rounds, rounds_bound, loop_body]
return b1.resolve(IRnode.from_list(loop, error_msg="range() bounds check"))

def _parse_For_list(self):
with self.context.range_scope():
Expand Down

0 comments on commit b9244e8

Please sign in to comment.