diff --git a/docs/statements.rst b/docs/statements.rst index 88b8d3a712..f443b2a38c 100644 --- a/docs/statements.rst +++ b/docs/statements.rst @@ -93,7 +93,7 @@ The ``raise`` statement triggers an exception and reverts the current call. raise "something went wrong" -The error string is not required. +The error string is not required. If it is provided, it is limited to 1024 bytes. assert ------ @@ -104,7 +104,7 @@ The ``assert`` statement makes an assertion about a given condition. If the cond assert x > 5, "value too low" -The error string is not required. +The error string is not required. If it is provided, it is limited to 1024 bytes. This method's behavior is equivalent to: diff --git a/tests/grammar/vyper.lark b/tests/grammar/vyper.lark index 5a354ade10..2b9783ac9e 100644 --- a/tests/grammar/vyper.lark +++ b/tests/grammar/vyper.lark @@ -139,10 +139,10 @@ log_stmt: _LOG NAME "(" [arguments] ")" return_stmt: _RETURN [_expr ("," _expr)*] _UNREACHABLE: "UNREACHABLE" raise_stmt: _RAISE -> raise - | _RAISE STRING -> raise_with_reason + | _RAISE _expr -> raise_with_reason | _RAISE _UNREACHABLE -> raise_unreachable assert_stmt: _ASSERT _expr -> assert - | _ASSERT _expr "," STRING -> assert_with_reason + | _ASSERT _expr "," _expr -> assert_with_reason | _ASSERT _expr "," _UNREACHABLE -> assert_unreachable body: _NEWLINE _INDENT ([COMMENT] _NEWLINE | _stmt)+ _DEDENT diff --git a/tests/parser/exceptions/test_invalid_type_exception.py b/tests/parser/exceptions/test_invalid_type_exception.py index d5641b78ee..f2f9a45720 100644 --- a/tests/parser/exceptions/test_invalid_type_exception.py +++ b/tests/parser/exceptions/test_invalid_type_exception.py @@ -42,7 +42,7 @@ def foo(): """ @external def mint(_to: address, _value: uint256): - assert msg.sender == self,minter + assert msg.sender == self,msg.sender """, # Raise reason must be string """ diff --git a/tests/parser/features/test_assert.py b/tests/parser/features/test_assert.py index adad5e57af..d31edc74a4 100644 --- a/tests/parser/features/test_assert.py +++ b/tests/parser/features/test_assert.py @@ -29,15 +29,15 @@ def test(a: int128) -> int128: return 1 + a @external -def test2(a: int128, b: int128) -> int128: +def test2(a: int128, b: int128, extra_reason: String[32]) -> int128: c: int128 = 11 assert a > 1, "a is not large enough" - assert b == 1, "b may only be 1" + assert b == 1, concat("b may only be 1", extra_reason) return a + b + c @external -def test3() : - raise "An exception" +def test3(reason_str: String[32]): + raise reason_str """ c = get_contract_with_gas_estimation(code) @@ -48,17 +48,17 @@ def test3() : assert e_info.value.args[0] == "larger than one please" # a = 0, b = 1 with pytest.raises(TransactionFailed) as e_info: - c.test2(0, 1) + c.test2(0, 1, "") assert e_info.value.args[0] == "a is not large enough" # a = 1, b = 0 with pytest.raises(TransactionFailed) as e_info: - c.test2(2, 2) - assert e_info.value.args[0] == "b may only be 1" + c.test2(2, 2, " because I said so") + assert e_info.value.args[0] == "b may only be 1" + " because I said so" # return correct value - assert c.test2(5, 1) == 17 + assert c.test2(5, 1, "") == 17 with pytest.raises(TransactionFailed) as e_info: - c.test3() + c.test3("An exception") assert e_info.value.args[0] == "An exception" diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 725428e804..bfce521505 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -100,6 +100,7 @@ def ret10_slice() -> Bytes[10]: c = get_contract(code) assert c.ret10_slice() == b"A" + def test_slice_expr(get_contract): # test slice of a complex expression code = """ @@ -112,7 +113,6 @@ def ret10_slice() -> Bytes[10]: assert c.ret10_slice() == b"A" - code_bytes32 = [ """ foo: bytes32 diff --git a/vyper/old_codegen/stmt.py b/vyper/old_codegen/stmt.py index b2e9979966..ff2fca79ce 100644 --- a/vyper/old_codegen/stmt.py +++ b/vyper/old_codegen/stmt.py @@ -2,13 +2,20 @@ import vyper.utils as util from vyper import ast as vy_ast from vyper.builtin_functions import STMT_DISPATCH_TABLE -from vyper.exceptions import StructureException, TypeCheckFailure +from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure from vyper.old_codegen import external_call, self_call -from vyper.old_codegen.context import Context +from vyper.old_codegen.context import Constancy, Context from vyper.old_codegen.expr import Expr -from vyper.old_codegen.parser_utils import LLLnode, getpos, make_setter, unwrap_location +from vyper.old_codegen.parser_utils import ( + LLLnode, + getpos, + make_byte_array_copier, + make_setter, + unwrap_location, + zero_pad, +) from vyper.old_codegen.return_ import make_return_stmt -from vyper.old_codegen.types import BaseType, ByteArrayType, ListType, get_size_of_type, parse_type +from vyper.old_codegen.types import BaseType, ByteArrayType, ListType, parse_type class Stmt: @@ -142,25 +149,47 @@ def _assert_reason(self, test_expr, msg): if isinstance(msg, vy_ast.Name) and msg.id == "UNREACHABLE": return LLLnode.from_list(["assert_unreachable", test_expr], typ=None, pos=getpos(msg)) - reason_str_type = ByteArrayType(len(msg.value.strip())) + # set constant so that revert reason str is well behaved + try: + tmp = self.context.constancy + self.context.constancy = Constancy.Constant + msg_lll = Expr(msg, self.context).lll_node + finally: + self.context.constancy = tmp + + # TODO this is probably useful in parser_utils + def _get_last(lll): + if len(lll.args) == 0: + return lll.value + return _get_last(lll.args[-1]) + + if msg_lll.location != "memory": + buf = self.context.new_internal_variable(msg_lll.typ) + instantiate_msg = make_byte_array_copier(buf, msg_lll) + else: + buf = _get_last(msg_lll) + if not isinstance(buf, int): + raise CompilerPanic(f"invalid bytestring {buf}\n{self}") + instantiate_msg = msg_lll - # abi encode the reason string - sig_placeholder = self.context.new_internal_variable(BaseType(32)) # offset of bytes in (bytes,) - arg_placeholder = self.context.new_internal_variable(BaseType(32)) - placeholder_bytes = Expr(msg, self.context).lll_node - method_id = util.abi_method_id("Error(string)") # abi encode method_id + bytestring + assert buf >= 36, "invalid buffer" + # we don't mind overwriting other memory because we are + # getting out of here anyway. + _runtime_length = ["mload", buf] revert_seq = [ "seq", - ["mstore", sig_placeholder, method_id], - ["mstore", arg_placeholder, 32], - placeholder_bytes, - ["revert", sig_placeholder + 28, int(32 + 4 + get_size_of_type(reason_str_type) * 32)], + instantiate_msg, + zero_pad(buf), + ["mstore", buf - 64, method_id], + ["mstore", buf - 32, 0x20], + ["revert", buf - 36, ["add", 4 + 32 + 32, ["ceil32", _runtime_length]]], ] - if test_expr: + + if test_expr is not None: lll_node = ["if", ["iszero", test_expr], revert_seq] else: lll_node = revert_seq @@ -183,7 +212,7 @@ def parse_Assert(self): def parse_Raise(self): if self.stmt.exc: - return self._assert_reason(0, self.stmt.exc) + return self._assert_reason(None, self.stmt.exc) else: return LLLnode.from_list(["revert", 0, 0], typ=None, pos=getpos(self.stmt)) @@ -394,19 +423,21 @@ def parse_Return(self): return make_return_stmt(lll_val, self.stmt, self.context) def _get_target(self, target): + _dbg_expr = target + if isinstance(target, vy_ast.Name) and target.id in self.context.forvars: - raise TypeCheckFailure("Failed for-loop constancy check") + raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}") if isinstance(target, vy_ast.Tuple): target = Expr(target, self.context).lll_node for node in target.args: if (node.location == "storage" and self.context.is_constant()) or not node.mutable: - raise TypeCheckFailure("Failed for-loop constancy check") + raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}") return target target = Expr.parse_variable_location(target, self.context) if (target.location == "storage" and self.context.is_constant()) or not target.mutable: - raise TypeCheckFailure("Failed for-loop constancy check") + raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}") return target diff --git a/vyper/semantics/validation/local.py b/vyper/semantics/validation/local.py index 383dbb0070..7dbe88a718 100644 --- a/vyper/semantics/validation/local.py +++ b/vyper/semantics/validation/local.py @@ -30,6 +30,7 @@ from vyper.semantics.types.user.event import Event from vyper.semantics.types.user.struct import StructDefinition from vyper.semantics.types.utils import get_type_from_annotation +from vyper.semantics.types.value.array_value import StringDefinition from vyper.semantics.types.value.boolean import BoolDefinition from vyper.semantics.types.value.numeric import Uint256Definition from vyper.semantics.validation.annotation import StatementAnnotationVisitor @@ -104,7 +105,10 @@ def _validate_revert_reason(msg_node: vy_ast.VyperNode) -> None: if not msg_node.value.strip(): raise StructureException("Reason string cannot be empty", msg_node) elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): - raise InvalidType("Reason must UNREACHABLE or a string literal", msg_node) + try: + validate_expected_type(msg_node, StringDefinition(1024)) + except TypeMismatch as e: + raise InvalidType("revert reason must fit within String[1024]") from e class FunctionNodeVisitor(VyperNodeVisitorBase):