Skip to content

Commit

Permalink
feat[lang]!: make @external modifier optional in .vyi files (#4178)
Browse files Browse the repository at this point in the history
make `@external` visibility in `.vyi` files optional.

additionally, fix a panic when functions in an interface have the wrong
visibility

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
cyberthirst and charles-cooper authored Aug 5, 2024
1 parent 85269b0 commit c0cf436
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 13 deletions.
31 changes: 31 additions & 0 deletions tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,34 @@ def test_call(a: address, b: {type_str}) -> {type_str}:
make_file("jsonabi.json", json.dumps(convert_v1_abi(abi)))
c3 = get_contract(code, input_bundle=input_bundle)
assert c3.test_call(c1.address, value) == value


def test_interface_function_without_visibility(make_input_bundle, get_contract):
interface_code = """
def foo() -> uint256:
...
@external
def bar() -> uint256:
...
"""

code = """
import a as FooInterface
implements: FooInterface
@external
def foo() -> uint256:
return 1
@external
def bar() -> uint256:
return 1
"""

input_bundle = make_input_bundle({"a.vyi": interface_code})

c = get_contract(code, input_bundle=input_bundle)

assert c.foo() == c.bar() == 1
78 changes: 78 additions & 0 deletions tests/functional/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,81 @@ def baz():
"""

assert compiler.compile_code(code, input_bundle=input_bundle) is not None


invalid_visibility_code = [
"""
import foo as Foo
implements: Foo
@external
def foobar():
pass
""",
"""
import foo as Foo
implements: Foo
@internal
def foobar():
pass
""",
"""
import foo as Foo
implements: Foo
def foobar():
pass
""",
]


@pytest.mark.parametrize("code", invalid_visibility_code)
def test_internal_visibility_in_interface(make_input_bundle, code):
interface_code = """
@internal
def foobar():
...
"""

input_bundle = make_input_bundle({"foo.vyi": interface_code})

with pytest.raises(FunctionDeclarationException) as e:
compiler.compile_code(code, input_bundle=input_bundle)

assert e.value._message == "Interface functions can only be marked as `@external`"


external_visibility_interface = [
"""
@external
def foobar():
...
def bar():
...
""",
"""
def foobar():
...
@external
def bar():
...
""",
]


@pytest.mark.parametrize("iface", external_visibility_interface)
def test_internal_implemenatation_of_external_interface(make_input_bundle, iface):
input_bundle = make_input_bundle({"foo.vyi": iface})

code = """
import foo as Foo
implements: Foo
@internal
def foobar():
pass
def bar():
pass
"""

with pytest.raises(InterfaceViolation) as e:
compiler.compile_code(code, input_bundle=input_bundle)

assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()"
8 changes: 4 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {}

# keep track of exported functions to prevent duplicate exports
self._exposed_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {}
self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {}

self._events: list[EventT] = []

Expand Down Expand Up @@ -414,7 +414,7 @@ def visit_ImplementsDecl(self, node):
raise StructureException(msg, node.annotation, hint=hint)

# grab exposed functions
funcs = self._exposed_functions
funcs = {fn_t: node for fn_t, node in self._all_functions.items() if fn_t.is_external}
type_.validate_implements(node, funcs)

node._metadata["interface_type"] = type_
Expand Down Expand Up @@ -608,10 +608,10 @@ def _self_t(self):
def _add_exposed_function(self, func_t, node, relax=True):
# call this before self._self_t.typ.add_member() for exception raising
# priority
if not relax and (prev_decl := self._exposed_functions.get(func_t)) is not None:
if not relax and (prev_decl := self._all_functions.get(func_t)) is not None:
raise StructureException("already exported!", node, prev_decl=prev_decl)

self._exposed_functions[func_t] = node
self._all_functions[func_t] = node

def visit_VariableDecl(self, node):
# postcondition of VariableDecl.validate
Expand Down
38 changes: 29 additions & 9 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,23 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)

if nonreentrant:
raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef)
# TODO: refactor so parse_decorators returns the AST location
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
raise FunctionDeclarationException(
"`@nonreentrant` not allowed in interfaces", decorator
)

# it's redundant to specify visibility in vyi - always should be external
if function_visibility is None:
function_visibility = FunctionVisibility.EXTERNAL

if function_visibility != FunctionVisibility.EXTERNAL:
nonexternal = next(
d for d in funcdef.decorator_list if d.id in FunctionVisibility.values()
)
raise FunctionDeclarationException(
"Interface functions can only be marked as `@external`", nonexternal
)

if funcdef.name == "__init__":
raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef)
Expand Down Expand Up @@ -381,6 +397,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
"""
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)

# it's redundant to specify internal visibility - it's implied by not being external
if function_visibility is None:
function_visibility = FunctionVisibility.INTERNAL

positional_args, keyword_args = _parse_args(funcdef)

return_type = _parse_return_type(funcdef)
Expand Down Expand Up @@ -419,6 +439,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
raise FunctionDeclarationException(
"Constructor may not use default arguments", funcdef.args.defaults[0]
)
if nonreentrant:
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)

return cls(
funcdef.name,
Expand Down Expand Up @@ -495,6 +519,8 @@ def implements(self, other: "ContractFunctionT") -> bool:
if not self.is_external: # pragma: nocover
raise CompilerPanic("unreachable!")

assert self.visibility == other.visibility

arguments, return_type = self._iface_sig
other_arguments, other_return_type = other._iface_sig

Expand Down Expand Up @@ -700,7 +726,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]:

def _parse_decorators(
funcdef: vy_ast.FunctionDef,
) -> tuple[FunctionVisibility, StateMutability, bool]:
) -> tuple[Optional[FunctionVisibility], StateMutability, bool]:
function_visibility = None
state_mutability = None
nonreentrant_node = None
Expand All @@ -719,10 +745,6 @@ def _parse_decorators(
if nonreentrant_node is not None:
raise StructureException("nonreentrant decorator is already set", nonreentrant_node)

if funcdef.name == "__init__":
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)

nonreentrant_node = decorator

elif isinstance(decorator, vy_ast.Name):
Expand All @@ -733,6 +755,7 @@ def _parse_decorators(
decorator,
hint="only one visibility decorator is allowed per function",
)

function_visibility = FunctionVisibility(decorator.id)

elif StateMutability.is_valid_value(decorator.id):
Expand All @@ -755,9 +778,6 @@ def _parse_decorators(
else:
raise StructureException("Bad decorator syntax", decorator)

if function_visibility is None:
function_visibility = FunctionVisibility.INTERNAL

if state_mutability is None:
# default to nonpayable
state_mutability = StateMutability.NONPAYABLE
Expand Down
3 changes: 3 additions & 0 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifia
def validate_implements(
self, node: vy_ast.ImplementsDecl, functions: dict[ContractFunctionT, vy_ast.VyperNode]
) -> None:
# only external functions can implement interfaces
fns_by_name = {fn_t.name: fn_t for fn_t in functions.keys()}

unimplemented = []
Expand All @@ -116,7 +117,9 @@ def _is_function_implemented(fn_name, fn_type):
return False

to_compare = fns_by_name[fn_name]
assert to_compare.is_external
assert isinstance(to_compare, ContractFunctionT)
assert isinstance(fn_type, ContractFunctionT)

return to_compare.implements(fn_type)

Expand Down

0 comments on commit c0cf436

Please sign in to comment.