Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix double evals in sqrt, slice, blueprint #3976

Merged
merged 16 commits into from
May 16, 2024
Merged
39 changes: 39 additions & 0 deletions tests/functional/builtins/codegen/test_create_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,45 @@ def test(target: address):
assert test.foo() == 12


def test_blueprint_evals_once_side_effects(get_contract, deploy_blueprint_for, env):
# test msize allocator does not get trampled by salt= kwarg
code = """
foo: public(uint256)
"""

deployer_code = """
created_address: public(address)
deployed: public(uint256)

@external
def get() -> Bytes[32]:
self.deployed += 1
return b''

@external
def create_(target: address):
self.created_address = create_from_blueprint(
target,
raw_call(self, method_id("get()"), max_outsize=32),
raw_args=True, code_offset=3
)
"""

foo_contract = get_contract(code)
expected_runtime_code = env.get_code(foo_contract.address)

f, FooContract = deploy_blueprint_for(code)

d = get_contract(deployer_code)

d.create_(f.address)

test = FooContract(d.created_address())
assert env.get_code(test.address) == expected_runtime_code
assert test.foo() == 0
assert d.deployed() == 1


def test_create_copy_of_complex_kwargs(get_contract, env):
# test msize allocator does not get trampled by salt= kwarg
complex_salt = """
Expand Down
19 changes: 19 additions & 0 deletions tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,22 @@ def test_slice_buffer_oob_reverts(bad_code, get_contract, tx_failed):
c = get_contract(bad_code)
with tx_failed():
c.do_slice()

cyberthirst marked this conversation as resolved.
Show resolved Hide resolved

def test_slice_length_eval_once(get_contract):
code = """
l: DynArray[uint256, 5]

@external
def foo(cs: String[64]) -> uint256:
self.l = [1, 1, 1, 1, 1]
assert len(self.l) == 5

s: Bytes[64] = b""
s = slice(msg.data, self.l.pop(), 3)

return len(self.l)
"""
arg = "a" * 64
c = get_contract(code)
assert c.foo(arg) == 4
19 changes: 19 additions & 0 deletions tests/functional/codegen/types/numbers/test_sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,22 @@ def test_sqrt_valid_range(sqrt_contract, value):
def test_sqrt_invalid_range(tx_failed, sqrt_contract, value):
with tx_failed():
sqrt_contract.test(decimal_to_int(value))


def test_sqrt_eval_once(get_contract):
code = """
c: uint256

@internal
def some_decimal() -> decimal:
self.c += 1
return 1.0

@external
def foo() -> uint256:
k: decimal = sqrt(self.some_decimal())
return self.c
"""

c = get_contract(code)
assert c.foo() == 1
142 changes: 78 additions & 64 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,44 +252,46 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
# allocate a buffer for the return value
buf = context.new_internal_variable(dst_typ)

# `msg.data` by `calldatacopy`
if sub.value == "~calldata":
node = [
"seq",
_make_slice_bounds_check(start, length, "calldatasize"),
["mstore", buf, length],
["calldatacopy", add_ofst(buf, 32), start, length],
buf,
]

# `self.code` by `codecopy`
elif sub.value == "~selfcode":
node = [
"seq",
_make_slice_bounds_check(start, length, "codesize"),
["mstore", buf, length],
["codecopy", add_ofst(buf, 32), start, length],
buf,
]
with scope_multi((start, length), ("start", "length")) as (b1, (start, length)):
# `msg.data` by `calldatacopy`
if sub.value == "~calldata":
node = [
"seq",
_make_slice_bounds_check(start, length, "calldatasize"),
["mstore", buf, length],
["calldatacopy", add_ofst(buf, 32), start, length],
buf,
]

# `<address>.code` by `extcodecopy`
else:
assert sub.value == "~extcode" and len(sub.args) == 1
node = [
"with",
"_extcode_address",
sub.args[0],
[
# `self.code` by `codecopy`
elif sub.value == "~selfcode":
node = [
"seq",
_make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
_make_slice_bounds_check(start, length, "codesize"),
["mstore", buf, length],
["extcodecopy", "_extcode_address", add_ofst(buf, 32), start, length],
["codecopy", add_ofst(buf, 32), start, length],
buf,
],
]
]

assert isinstance(length.value, int) # mypy hint
return IRnode.from_list(node, typ=BytesT(length.value), location=MEMORY)
# `<address>.code` by `extcodecopy`
else:
assert sub.value == "~extcode" and len(sub.args) == 1
node = [
"with",
"_extcode_address",
sub.args[0],
[
"seq",
_make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
["mstore", buf, length],
["extcodecopy", "_extcode_address", add_ofst(buf, 32), start, length],
buf,
],
]

assert isinstance(length.value, int) # mypy hint
ret = IRnode.from_list(node, typ=BytesT(length.value), location=MEMORY)
return b1.resolve(ret)


# note: this and a lot of other builtins could be refactored to accept any uint type
Expand Down Expand Up @@ -1816,9 +1818,15 @@ def _build_create_IR(
if len(ctor_args) != 1 or not isinstance(ctor_args[0].typ, BytesT):
raise StructureException("raw_args must be used with exactly 1 bytes argument")

argbuf = bytes_data_ptr(ctor_args[0])
argslen = get_bytearray_length(ctor_args[0])
bufsz = ctor_args[0].typ.maxlen
with ctor_args[0].cache_when_complex("arg") as (b1, arg):
argbuf = bytes_data_ptr(arg)
argslen = get_bytearray_length(arg)
bufsz = arg.typ.maxlen
return b1.resolve(
self._helper(
argbuf, bufsz, target, value, salt, argslen, code_offset, revert_on_failure
)
)
else:
# encode the varargs
to_encode = ir_tuple_from_args(ctor_args)
Expand All @@ -1831,7 +1839,11 @@ def _build_create_IR(
# return a complex expression which writes to memory and returns
# the length of the encoded data
argslen = abi_encode(argbuf, to_encode, context, bufsz=bufsz, returns_len=True)
return self._helper(
argbuf, bufsz, target, value, salt, argslen, code_offset, revert_on_failure
)

def _helper(self, argbuf, bufsz, target, value, salt, argslen, code_offset, revert_on_failure):
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
# NOTE: we need to invoke the abi encoder before evaluating MSIZE,
# then copy the abi encoded buffer to past-the-end of the initcode
# (since the abi encoder could write to fresh memory).
Expand Down Expand Up @@ -2118,7 +2130,8 @@ def build_IR(self, expr, args, kwargs, context):

arg = args[0]
# TODO: reify decimal and integer sqrt paths (see isqrt)
sqrt_code = """
with arg.cache_when_complex("x") as (b1, arg):
sqrt_code = """
assert x >= 0.0
z: decimal = 0.0

Expand All @@ -2133,33 +2146,34 @@ def build_IR(self, expr, args, kwargs, context):
break
y = z
z = (x / z + z) / 2.0
"""

x_type = DecimalT()
placeholder_copy = ["pass"]
# Steal current position if variable is already allocated.
if arg.value == "mload":
new_var_pos = arg.args[0]
# Other locations need to be copied.
else:
new_var_pos = context.new_internal_variable(x_type)
placeholder_copy = ["mstore", new_var_pos, arg]
# Create input variables.
variables = {"x": VariableRecord(name="x", pos=new_var_pos, typ=x_type, mutable=False)}
# Dictionary to update new (i.e. typecheck) namespace
variables_2 = {"x": VarInfo(DecimalT())}
# Generate inline IR.
new_ctx, sqrt_ir = generate_inline_function(
code=sqrt_code,
variables=variables,
variables_2=variables_2,
memory_allocator=context.memory_allocator,
)
return IRnode.from_list(
["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable
typ=DecimalT(),
location=MEMORY,
)
"""

x_type = DecimalT()
placeholder_copy = ["pass"]
# Steal current position if variable is already allocated.
if arg.value == "mload":
new_var_pos = arg.args[0]
# Other locations need to be copied.
else:
new_var_pos = context.new_internal_variable(x_type)
placeholder_copy = ["mstore", new_var_pos, arg]
# Create input variables.
variables = {"x": VariableRecord(name="x", pos=new_var_pos, typ=x_type, mutable=False)}
# Dictionary to update new (i.e. typecheck) namespace
variables_2 = {"x": VarInfo(DecimalT())}
# Generate inline IR.
new_ctx, sqrt_ir = generate_inline_function(
code=sqrt_code,
variables=variables,
variables_2=variables_2,
memory_allocator=context.memory_allocator,
)
ret = IRnode.from_list(
["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable
typ=DecimalT(),
location=MEMORY,
)
return b1.resolve(ret)


class ISqrt(BuiltinFunctionT):
Expand Down
Loading