diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 9442362696..9ea0b58d89 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -695,3 +695,34 @@ def test_call(a: address, b: {type_str}) -> {type_str}: make_file("jsonabi.json", json.dumps(convert_v1_abi(abi))) c3 = get_contract(code, input_bundle=input_bundle) assert c3.test_call(c1.address, value) == value + + +def test_interface_function_without_visibility(make_input_bundle, get_contract): + interface_code = """ +def foo() -> uint256: + ... + +@external +def bar() -> uint256: + ... + """ + + code = """ +import a as FooInterface + +implements: FooInterface + +@external +def foo() -> uint256: + return 1 + +@external +def bar() -> uint256: + return 1 + """ + + input_bundle = make_input_bundle({"a.vyi": interface_code}) + + c = get_contract(code, input_bundle=input_bundle) + + assert c.foo() == c.bar() == 1 diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index d2e0077cd6..ea06e0ab2f 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -484,3 +484,81 @@ def baz(): """ assert compiler.compile_code(code, input_bundle=input_bundle) is not None + + +invalid_visibility_code = [ + """ +import foo as Foo +implements: Foo +@external +def foobar(): + pass + """, + """ +import foo as Foo +implements: Foo +@internal +def foobar(): + pass + """, + """ +import foo as Foo +implements: Foo +def foobar(): + pass + """, +] + + +@pytest.mark.parametrize("code", invalid_visibility_code) +def test_internal_visibility_in_interface(make_input_bundle, code): + interface_code = """ +@internal +def foobar(): + ... +""" + + input_bundle = make_input_bundle({"foo.vyi": interface_code}) + + with pytest.raises(FunctionDeclarationException) as e: + compiler.compile_code(code, input_bundle=input_bundle) + + assert e.value._message == "Interface functions can only be marked as `@external`" + + +external_visibility_interface = [ + """ +@external +def foobar(): + ... +def bar(): + ... + """, + """ +def foobar(): + ... +@external +def bar(): + ... + """, +] + + +@pytest.mark.parametrize("iface", external_visibility_interface) +def test_internal_implemenatation_of_external_interface(make_input_bundle, iface): + input_bundle = make_input_bundle({"foo.vyi": iface}) + + code = """ +import foo as Foo +implements: Foo +@internal +def foobar(): + pass +def bar(): + pass + """ + + with pytest.raises(InterfaceViolation) as e: + compiler.compile_code(code, input_bundle=input_bundle) + + assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()" diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index d6bbea1b48..d3de219c03 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -193,7 +193,7 @@ def __init__( self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {} # keep track of exported functions to prevent duplicate exports - self._exposed_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {} + self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {} self._events: list[EventT] = [] @@ -414,7 +414,7 @@ def visit_ImplementsDecl(self, node): raise StructureException(msg, node.annotation, hint=hint) # grab exposed functions - funcs = self._exposed_functions + funcs = {fn_t: node for fn_t, node in self._all_functions.items() if fn_t.is_external} type_.validate_implements(node, funcs) node._metadata["interface_type"] = type_ @@ -608,10 +608,10 @@ def _self_t(self): def _add_exposed_function(self, func_t, node, relax=True): # call this before self._self_t.typ.add_member() for exception raising # priority - if not relax and (prev_decl := self._exposed_functions.get(func_t)) is not None: + if not relax and (prev_decl := self._all_functions.get(func_t)) is not None: raise StructureException("already exported!", node, prev_decl=prev_decl) - self._exposed_functions[func_t] = node + self._all_functions[func_t] = node def visit_VariableDecl(self, node): # postcondition of VariableDecl.validate diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 72682c881d..7a56b01281 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -330,7 +330,23 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) if nonreentrant: - raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef) + # TODO: refactor so parse_decorators returns the AST location + decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant") + raise FunctionDeclarationException( + "`@nonreentrant` not allowed in interfaces", decorator + ) + + # it's redundant to specify visibility in vyi - always should be external + if function_visibility is None: + function_visibility = FunctionVisibility.EXTERNAL + + if function_visibility != FunctionVisibility.EXTERNAL: + nonexternal = next( + d for d in funcdef.decorator_list if d.id in FunctionVisibility.values() + ) + raise FunctionDeclarationException( + "Interface functions can only be marked as `@external`", nonexternal + ) if funcdef.name == "__init__": raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) @@ -381,6 +397,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": """ function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) + # it's redundant to specify internal visibility - it's implied by not being external + if function_visibility is None: + function_visibility = FunctionVisibility.INTERNAL + positional_args, keyword_args = _parse_args(funcdef) return_type = _parse_return_type(funcdef) @@ -419,6 +439,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": raise FunctionDeclarationException( "Constructor may not use default arguments", funcdef.args.defaults[0] ) + if nonreentrant: + decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant") + msg = "`@nonreentrant` decorator disallowed on `__init__`" + raise FunctionDeclarationException(msg, decorator) return cls( funcdef.name, @@ -495,6 +519,8 @@ def implements(self, other: "ContractFunctionT") -> bool: if not self.is_external: # pragma: nocover raise CompilerPanic("unreachable!") + assert self.visibility == other.visibility + arguments, return_type = self._iface_sig other_arguments, other_return_type = other._iface_sig @@ -700,7 +726,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, bool]: +) -> tuple[Optional[FunctionVisibility], StateMutability, bool]: function_visibility = None state_mutability = None nonreentrant_node = None @@ -719,10 +745,6 @@ def _parse_decorators( if nonreentrant_node is not None: raise StructureException("nonreentrant decorator is already set", nonreentrant_node) - if funcdef.name == "__init__": - msg = "`@nonreentrant` decorator disallowed on `__init__`" - raise FunctionDeclarationException(msg, decorator) - nonreentrant_node = decorator elif isinstance(decorator, vy_ast.Name): @@ -733,6 +755,7 @@ def _parse_decorators( decorator, hint="only one visibility decorator is allowed per function", ) + function_visibility = FunctionVisibility(decorator.id) elif StateMutability.is_valid_value(decorator.id): @@ -755,9 +778,6 @@ def _parse_decorators( else: raise StructureException("Bad decorator syntax", decorator) - if function_visibility is None: - function_visibility = FunctionVisibility.INTERNAL - if state_mutability is None: # default to nonpayable state_mutability = StateMutability.NONPAYABLE diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index e55c4d145f..60eb93bcac 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -107,6 +107,7 @@ def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifia def validate_implements( self, node: vy_ast.ImplementsDecl, functions: dict[ContractFunctionT, vy_ast.VyperNode] ) -> None: + # only external functions can implement interfaces fns_by_name = {fn_t.name: fn_t for fn_t in functions.keys()} unimplemented = [] @@ -116,7 +117,9 @@ def _is_function_implemented(fn_name, fn_type): return False to_compare = fns_by_name[fn_name] + assert to_compare.is_external assert isinstance(to_compare, ContractFunctionT) + assert isinstance(fn_type, ContractFunctionT) return to_compare.implements(fn_type)