Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pop for loop values from stack prior to returning #2110

Merged
merged 3 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 83 additions & 8 deletions tests/parser/features/iteration/test_repeater.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def repeat(z: int128) -> int128:
"""
c = get_contract_with_gas_estimation(basic_repeater)
assert c.repeat(9) == 54
print("Passed basic repeater test")


def test_digit_reverser(get_contract_with_gas_estimation):
Expand All @@ -30,7 +29,6 @@ def reverse_digits(x: int128) -> int128:

c = get_contract_with_gas_estimation(digit_reverser)
assert c.reverse_digits(123456) == 654321
print("Passed digit reverser test")


def test_more_complex_repeater(get_contract_with_gas_estimation):
Expand All @@ -48,8 +46,6 @@ def repeat() -> int128:
c = get_contract_with_gas_estimation(more_complex_repeater)
assert c.repeat() == 666666

print("Passed complex repeater test")


def test_offset_repeater(get_contract_with_gas_estimation):
offset_repeater = """
Expand All @@ -64,8 +60,6 @@ def sum() -> int128:
c = get_contract_with_gas_estimation(offset_repeater)
assert c.sum() == 4100

print("Passed repeater with offset test")


def test_offset_repeater_2(get_contract_with_gas_estimation):
offset_repeater_2 = """
Expand All @@ -83,8 +77,6 @@ def sum(frm: int128, to: int128) -> int128:
assert c.sum(100, 99999) == 15150
assert c.sum(70, 131) == 6100

print("Passed more complex repeater with offset test")


def test_loop_call_priv(get_contract_with_gas_estimation):
code = """
Expand All @@ -101,3 +93,86 @@ def foo() -> bool:

c = get_contract_with_gas_estimation(code)
assert c.foo() is True


def test_return_inside_repeater(get_contract):
code = """
@internal
def _final(a: int128) -> int128:
for i in range(10):
if i > a:
return i
return -42

@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
return b

@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(6) == 7
assert c.foo(100) == -42


def test_return_inside_nested_repeater(get_contract):
code = """
@internal
def _final(a: int128) -> int128:
for i in range(10):
for x in range(10):
if i + x > a:
return i + x
return -42

@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
return b

@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(14) == 15
assert c.foo(100) == -42


def test_breaks_and_returns_inside_nested_repeater(get_contract):
code = """
@internal
def _final(a: int128) -> int128:
for i in range(10):
for x in range(10):
if a < 2:
break
return 6
if a == 1:
break
return 31337

return -42

@internal
def _middle(a: int128) -> int128:
b: int128 = self._final(a)
return b

@external
def foo(a: int128) -> int128:
b: int128 = self._middle(a)
return b
"""

c = get_contract(code)
assert c.foo(100) == 6
assert c.foo(1) == -42
assert c.foo(0) == 31337
23 changes: 11 additions & 12 deletions vyper/codegen/return_.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None)
if isinstance(begin_pos, int) and isinstance(_size, int):
# static values, unroll the mloads instead.
mloads = [["mload", pos] for pos in range(begin_pos, _size, 32)]
return (
["seq_unchecked"]
+ mloads
+ nonreentrant_post
+ [["jump", ["mload", context.callback_ptr]]]
)
else:
mloads = [
"seq_unchecked",
Expand All @@ -54,12 +48,17 @@ def make_return_stmt(stmt, context, begin_pos, _size, loop_memory_position=None)
["goto", start_label],
["label", exit_label],
]
return (
["seq_unchecked"]
+ [mloads]
+ nonreentrant_post
+ [["jump", ["mload", context.callback_ptr]]]
)

# if we are in a for loop, we have to exit prior to returning
exit_repeater = ["exit_repeater"] if context.forvars else []

return (
["seq_unchecked"]
+ exit_repeater
+ mloads
+ nonreentrant_post
+ [["jump", ["mload", context.callback_ptr]]]
)
else:
return ["seq_unchecked"] + nonreentrant_post + [["return", begin_pos, _size]]

Expand Down
10 changes: 8 additions & 2 deletions vyper/compile_lll.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,21 @@ def compile_to_assembly(code, withargs=None, existing_labels=None, break_dest=No
# Continue to the next iteration of the for loop
elif code.value == "continue":
if not break_dest:
raise Exception("Invalid break")
raise CompilerPanic("Invalid break")
dest, continue_dest, break_height = break_dest
return [continue_dest, "JUMP"]
# Break from inside a for loop
elif code.value == "break":
if not break_dest:
raise Exception("Invalid break")
raise CompilerPanic("Invalid break")
dest, continue_dest, break_height = break_dest
return ["POP"] * (height - break_height) + [dest, "JUMP"]
# Break from inside one or more for loops prior to a return statement inside the loop
elif code.value == "exit_repeater":
if not break_dest:
raise CompilerPanic("Invalid break")
_, _, break_height = break_dest
return ["POP"] * break_height
# With statements
elif code.value == "with":
o = []
Expand Down