diff --git a/examples/tokens/ERC20.vy b/examples/tokens/ERC20.vy index 3044c92c89..2d70fd670a 100644 --- a/examples/tokens/ERC20.vy +++ b/examples/tokens/ERC20.vy @@ -14,16 +14,6 @@ from ethereum.ercs import IERC20Detailed implements: IERC20 implements: IERC20Detailed -event Transfer: - sender: indexed(address) - receiver: indexed(address) - value: uint256 - -event Approval: - owner: indexed(address) - spender: indexed(address) - value: uint256 - name: public(String[32]) symbol: public(String[32]) decimals: public(uint8) @@ -49,7 +39,7 @@ def __init__(_name: String[32], _symbol: String[32], _decimals: uint8, _supply: self.balanceOf[msg.sender] = init_supply self.totalSupply = init_supply self.minter = msg.sender - log Transfer(empty(address), msg.sender, init_supply) + log IERC20.Transfer(empty(address), msg.sender, init_supply) @@ -64,7 +54,7 @@ def transfer(_to : address, _value : uint256) -> bool: # so the following subtraction would revert on insufficient balance self.balanceOf[msg.sender] -= _value self.balanceOf[_to] += _value - log Transfer(msg.sender, _to, _value) + log IERC20.Transfer(msg.sender, _to, _value) return True @@ -83,7 +73,7 @@ def transferFrom(_from : address, _to : address, _value : uint256) -> bool: # NOTE: vyper does not allow underflows # so the following subtraction would revert on insufficient allowance self.allowance[_from][msg.sender] -= _value - log Transfer(_from, _to, _value) + log IERC20.Transfer(_from, _to, _value) return True @@ -99,7 +89,7 @@ def approve(_spender : address, _value : uint256) -> bool: @param _value The amount of tokens to be spent. """ self.allowance[msg.sender][_spender] = _value - log Approval(msg.sender, _spender, _value) + log IERC20.Approval(msg.sender, _spender, _value) return True @@ -116,7 +106,7 @@ def mint(_to: address, _value: uint256): assert _to != empty(address) self.totalSupply += _value self.balanceOf[_to] += _value - log Transfer(empty(address), _to, _value) + log IERC20.Transfer(empty(address), _to, _value) @internal @@ -130,7 +120,7 @@ def _burn(_to: address, _value: uint256): assert _to != empty(address) self.totalSupply -= _value self.balanceOf[_to] -= _value - log Transfer(_to, empty(address), _value) + log IERC20.Transfer(_to, empty(address), _value) @external diff --git a/examples/tokens/ERC4626.vy b/examples/tokens/ERC4626.vy index aadf1f4f13..a175fd3aa7 100644 --- a/examples/tokens/ERC4626.vy +++ b/examples/tokens/ERC4626.vy @@ -24,33 +24,10 @@ NAME: constant(String[10]) = "Test Vault" SYMBOL: constant(String[5]) = "vTEST" DECIMALS: constant(uint8) = 18 -event Transfer: - sender: indexed(address) - receiver: indexed(address) - amount: uint256 - -event Approval: - owner: indexed(address) - spender: indexed(address) - allowance: uint256 - ##### ERC4626 ##### asset: public(IERC20) -event Deposit: - depositor: indexed(address) - receiver: indexed(address) - assets: uint256 - shares: uint256 - -event Withdraw: - withdrawer: indexed(address) - receiver: indexed(address) - owner: indexed(address) - assets: uint256 - shares: uint256 - @deploy def __init__(asset: IERC20): @@ -79,14 +56,14 @@ def decimals() -> uint8: def transfer(receiver: address, amount: uint256) -> bool: self.balanceOf[msg.sender] -= amount self.balanceOf[receiver] += amount - log Transfer(msg.sender, receiver, amount) + log IERC20.Transfer(msg.sender, receiver, amount) return True @external def approve(spender: address, amount: uint256) -> bool: self.allowance[msg.sender][spender] = amount - log Approval(msg.sender, spender, amount) + log IERC20.Approval(msg.sender, spender, amount) return True @@ -95,7 +72,7 @@ def transferFrom(sender: address, receiver: address, amount: uint256) -> bool: self.allowance[sender][msg.sender] -= amount self.balanceOf[sender] -= amount self.balanceOf[receiver] += amount - log Transfer(sender, receiver, amount) + log IERC20.Transfer(sender, receiver, amount) return True @@ -160,7 +137,7 @@ def deposit(assets: uint256, receiver: address=msg.sender) -> uint256: self.totalSupply += shares self.balanceOf[receiver] += shares - log Deposit(msg.sender, receiver, assets, shares) + log IERC4626.Deposit(msg.sender, receiver, assets, shares) return shares @@ -193,7 +170,7 @@ def mint(shares: uint256, receiver: address=msg.sender) -> uint256: self.totalSupply += shares self.balanceOf[receiver] += shares - log Deposit(msg.sender, receiver, assets, shares) + log IERC4626.Deposit(msg.sender, receiver, assets, shares) return assets @@ -230,7 +207,7 @@ def withdraw(assets: uint256, receiver: address=msg.sender, owner: address=msg.s self.balanceOf[owner] -= shares self.asset.transfer(receiver, assets) - log Withdraw(msg.sender, receiver, owner, assets, shares) + log IERC4626.Withdraw(msg.sender, receiver, owner, assets, shares) return shares @@ -256,7 +233,7 @@ def redeem(shares: uint256, receiver: address=msg.sender, owner: address=msg.sen self.balanceOf[owner] -= shares self.asset.transfer(receiver, assets) - log Withdraw(msg.sender, receiver, owner, assets, shares) + log IERC4626.Withdraw(msg.sender, receiver, owner, assets, shares) return assets diff --git a/examples/tokens/ERC721.vy b/examples/tokens/ERC721.vy index 2399f31947..5ae9365200 100644 --- a/examples/tokens/ERC721.vy +++ b/examples/tokens/ERC721.vy @@ -24,41 +24,6 @@ interface ERC721Receiver: ) -> bytes4: nonpayable -# @dev Emits when ownership of any NFT changes by any mechanism. This event emits when NFTs are -# created (`from` == 0) and destroyed (`to` == 0). Exception: during contract creation, any -# number of NFTs may be created and assigned without emitting Transfer. At the time of any -# transfer, the approved address for that NFT (if any) is reset to none. -# @param _from Sender of NFT (if address is zero address it indicates token creation). -# @param _to Receiver of NFT (if address is zero address it indicates token destruction). -# @param _tokenId The NFT that got transferred. -event Transfer: - sender: indexed(address) - receiver: indexed(address) - tokenId: indexed(uint256) - -# @dev This emits when the approved address for an NFT is changed or reaffirmed. The zero -# address indicates there is no approved address. When a Transfer event emits, this also -# indicates that the approved address for that NFT (if any) is reset to none. -# @param _owner Owner of NFT. -# @param _approved Address that we are approving. -# @param _tokenId NFT which we are approving. -event Approval: - owner: indexed(address) - approved: indexed(address) - tokenId: indexed(uint256) - -# @dev This emits when an operator is enabled or disabled for an owner. The operator can manage -# all NFTs of the owner. -# @param _owner Owner of NFT. -# @param _operator Address to which we are setting operator rights. -# @param _approved Status of operator rights(true if operator rights are given and false if -# revoked). -event ApprovalForAll: - owner: indexed(address) - operator: indexed(address) - approved: bool - - # @dev Mapping from NFT ID to the address that owns it. idToOwner: HashMap[uint256, address] @@ -236,7 +201,7 @@ def _transferFrom(_from: address, _to: address, _tokenId: uint256, _sender: addr # Add NFT self._addTokenTo(_to, _tokenId) # Log the transfer - log Transfer(_from, _to, _tokenId) + log IERC721.Transfer(_from, _to, _tokenId) ### TRANSFER FUNCTIONS ### @@ -310,7 +275,7 @@ def approve(_approved: address, _tokenId: uint256): assert (senderIsOwner or senderIsApprovedForAll) # Set the approval self.idToApprovals[_tokenId] = _approved - log Approval(owner, _approved, _tokenId) + log IERC721.Approval(owner, _approved, _tokenId) @external @@ -326,7 +291,7 @@ def setApprovalForAll(_operator: address, _approved: bool): # Throws if `_operator` is the `msg.sender` assert _operator != msg.sender self.ownerToOperators[msg.sender][_operator] = _approved - log ApprovalForAll(msg.sender, _operator, _approved) + log IERC721.ApprovalForAll(msg.sender, _operator, _approved) ### MINT & BURN FUNCTIONS ### @@ -348,7 +313,7 @@ def mint(_to: address, _tokenId: uint256) -> bool: assert _to != empty(address) # Add NFT. Throws if `_tokenId` is owned by someone self._addTokenTo(_to, _tokenId) - log Transfer(empty(address), _to, _tokenId) + log IERC721.Transfer(empty(address), _to, _tokenId) return True @@ -368,7 +333,7 @@ def burn(_tokenId: uint256): assert owner != empty(address) self._clearApproval(owner, _tokenId) self._removeTokenFrom(owner, _tokenId) - log Transfer(owner, empty(address), _tokenId) + log IERC721.Transfer(owner, empty(address), _tokenId) @view diff --git a/tests/functional/builtins/codegen/test_abi.py b/tests/functional/builtins/codegen/test_abi.py deleted file mode 100644 index 6318ffd883..0000000000 --- a/tests/functional/builtins/codegen/test_abi.py +++ /dev/null @@ -1,294 +0,0 @@ -import pytest - -from vyper.compiler import compile_code -from vyper.compiler.output import build_abi_output -from vyper.compiler.phases import CompilerData - -source_codes = [ - """ -x: int128 - -@deploy -def __init__(): - self.x = 1 - """, - """ -x: int128 - -@deploy -def __init__(): - pass - """, -] - - -@pytest.mark.parametrize("source_code", source_codes) -def test_only_init_function(source_code): - empty_sig = [ - {"outputs": [], "inputs": [], "stateMutability": "nonpayable", "type": "constructor"} - ] - - data = CompilerData(source_code) - assert build_abi_output(data) == empty_sig - - -def test_default_abi(): - default_code = """ -@payable -@external -def __default__(): - pass - """ - - data = CompilerData(default_code) - assert build_abi_output(data) == [{"stateMutability": "payable", "type": "fallback"}] - - -def test_method_identifiers(): - code = """ -x: public(int128) - -@external -def foo(y: uint256) -> Bytes[100]: - return b"hello" - """ - - out = compile_code(code, output_formats=["method_identifiers"]) - - assert out["method_identifiers"] == {"foo(uint256)": "0x2fbebd38", "x()": "0xc55699c"} - - -def test_struct_abi(): - code = """ -struct MyStruct: - a: address - b: uint256 - -@external -@view -def foo(s: MyStruct) -> MyStruct: - return s - """ - - data = CompilerData(code) - abi = build_abi_output(data) - func_abi = abi[0] - - assert func_abi["name"] == "foo" - - expected_output = [ - { - "type": "tuple", - "name": "", - "components": [{"type": "address", "name": "a"}, {"type": "uint256", "name": "b"}], - } - ] - - assert func_abi["outputs"] == expected_output - - expected_input = { - "type": "tuple", - "name": "s", - "components": [{"type": "address", "name": "a"}, {"type": "uint256", "name": "b"}], - } - - assert func_abi["inputs"][0] == expected_input - - -@pytest.mark.parametrize( - "type,abi_type", [("DynArray[NestedStruct, 2]", "tuple[]"), ("NestedStruct[2]", "tuple[2]")] -) -def test_nested_struct(type, abi_type): - code = f""" -struct MyStruct: - a: address - b: bytes32 - -struct NestedStruct: - t: MyStruct - foo: uint256 - -@view -@external -def getStructList() -> {type}: - return [ - NestedStruct(t=MyStruct(a=msg.sender, b=block.prevhash), foo=1), - NestedStruct(t=MyStruct(a=msg.sender, b=block.prevhash), foo=2) - ] - """ - - out = compile_code(code, output_formats=["abi"]) - - assert out["abi"] == [ - { - "inputs": [], - "name": "getStructList", - "outputs": [ - { - "components": [ - { - "components": [ - {"name": "a", "type": "address"}, - {"name": "b", "type": "bytes32"}, - ], - "name": "t", - "type": "tuple", - }, - {"name": "foo", "type": "uint256"}, - ], - "name": "", - "type": f"{abi_type}", - } - ], - "stateMutability": "view", - "type": "function", - } - ] - - -@pytest.mark.parametrize( - "type,abi_type", [("DynArray[DynArray[Foo, 2], 2]", "tuple[][]"), ("Foo[2][2]", "tuple[2][2]")] -) -def test_2d_list_of_struct(type, abi_type): - code = f""" -struct Foo: - a: uint256 - b: uint256 - -@view -@external -def bar(x: {type}): - pass - """ - - out = compile_code(code, output_formats=["abi"]) - - assert out["abi"] == [ - { - "inputs": [ - { - "components": [ - {"name": "a", "type": "uint256"}, - {"name": "b", "type": "uint256"}, - ], - "name": "x", - "type": f"{abi_type}", - } - ], - "name": "bar", - "outputs": [], - "stateMutability": "view", - "type": "function", - } - ] - - -def test_exports_abi(make_input_bundle): - lib1 = """ -@external -def foo(): - pass - -@external -def bar(): - pass - """ - - main = """ -import lib1 - -initializes: lib1 - -exports: lib1.foo - """ - input_bundle = make_input_bundle({"lib1.vy": lib1}) - out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) - - # just for clarity -- check bar() is not in the output - for fn in out["abi"]: - assert fn["name"] != "bar" - - expected = [ - { - "inputs": [], - "name": "foo", - "outputs": [], - "stateMutability": "nonpayable", - "type": "function", - } - ] - - assert out["abi"] == expected - - -def test_exports_variable(make_input_bundle): - lib1 = """ -@external -def foo(): - pass - -private_storage_variable: uint256 -private_immutable_variable: immutable(uint256) -private_constant_variable: constant(uint256) = 3 - -public_storage_variable: public(uint256) -public_immutable_variable: public(immutable(uint256)) -public_constant_variable: public(constant(uint256)) = 10 - -@deploy -def __init__(a: uint256, b: uint256): - public_immutable_variable = a - private_immutable_variable = b - """ - - main = """ -import lib1 - -initializes: lib1 - -exports: ( - lib1.foo, - lib1.public_storage_variable, - lib1.public_immutable_variable, - lib1.public_constant_variable, -) - -@deploy -def __init__(): - lib1.__init__(5, 6) - """ - input_bundle = make_input_bundle({"lib1.vy": lib1}) - out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) - expected = [ - { - "inputs": [], - "name": "foo", - "outputs": [], - "stateMutability": "nonpayable", - "type": "function", - }, - { - "inputs": [], - "name": "public_storage_variable", - "outputs": [{"name": "", "type": "uint256"}], - "stateMutability": "view", - "type": "function", - }, - { - "inputs": [], - "name": "public_immutable_variable", - "outputs": [{"name": "", "type": "uint256"}], - "stateMutability": "view", - "type": "function", - }, - { - "inputs": [], - "name": "public_constant_variable", - "outputs": [{"name": "", "type": "uint256"}], - "stateMutability": "view", - "type": "function", - }, - {"inputs": [], "outputs": [], "stateMutability": "nonpayable", "type": "constructor"}, - ] - - assert out["abi"] == expected diff --git a/tests/functional/codegen/modules/test_events.py b/tests/functional/codegen/modules/test_events.py new file mode 100644 index 0000000000..8cec4c6577 --- /dev/null +++ b/tests/functional/codegen/modules/test_events.py @@ -0,0 +1,66 @@ +def test_module_event(get_contract, make_input_bundle, get_logs): + # log from a module + lib1 = """ +event MyEvent: + pass + +@internal +def foo(): + log MyEvent() + """ + main = """ +import lib1 + +@external +def bar(): + lib1.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + logs = get_logs(c.bar(transact={}), c, "MyEvent") + assert len(logs) == 1 + + +def test_module_event2(get_contract, make_input_bundle, get_logs): + # log a module event from main contract + lib1 = """ +event MyEvent: + x: uint256 + """ + main = """ +import lib1 + +@external +def bar(): + log lib1.MyEvent(5) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + logs = get_logs(c.bar(transact={}), c, "MyEvent") + assert len(logs) == 1 + assert logs[0].args.x == 5 + + +def test_module_event_indexed(get_contract, make_input_bundle, get_logs): + lib1 = """ +event MyEvent: + x: uint256 + y: indexed(uint256) + +@internal +def foo(): + log MyEvent(5, 6) + """ + main = """ +import lib1 + +@external +def bar(): + lib1.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + c = get_contract(main, input_bundle=input_bundle) + logs = get_logs(c.bar(transact={}), c, "MyEvent") + assert len(logs) == 1 + assert logs[0].args.x == 5 + assert logs[0].args.y == 6 diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 785ba938ae..3d1d26e999 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -125,7 +125,7 @@ def foo() -> uint256: compile_code(not_implemented_code, input_bundle=input_bundle) -def test_missing_event(make_input_bundle, assert_compile_failed): +def test_log_interface_event(make_input_bundle, assert_compile_failed): interface_code = """ event Foo: a: uint256 @@ -133,102 +133,18 @@ def test_missing_event(make_input_bundle, assert_compile_failed): input_bundle = make_input_bundle({"a.vyi": interface_code}) - not_implemented_code = """ -import a as FooBarInterface - -implements: FooBarInterface - -@external -def bar() -> uint256: - return 1 - """ - - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) - - -# check that event types match -def test_malformed_event(make_input_bundle, assert_compile_failed): - interface_code = """ -event Foo: - a: uint256 - """ - - input_bundle = make_input_bundle({"a.vyi": interface_code}) - - not_implemented_code = """ -import a as FooBarInterface - -implements: FooBarInterface - -event Foo: - a: int128 - -@external -def bar() -> uint256: - return 1 - """ - - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) - - -# check that event non-indexed arg needs to match interface -def test_malformed_events_indexed(make_input_bundle, assert_compile_failed): - interface_code = """ -event Foo: - a: uint256 - """ - - input_bundle = make_input_bundle({"a.vyi": interface_code}) - - not_implemented_code = """ -import a as FooBarInterface - -implements: FooBarInterface - -# a should not be indexed -event Foo: - a: indexed(uint256) - -@external -def bar() -> uint256: - return 1 - """ - - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) - - -# check that event indexed arg needs to match interface -def test_malformed_events_indexed2(make_input_bundle, assert_compile_failed): - interface_code = """ -event Foo: - a: indexed(uint256) - """ - - input_bundle = make_input_bundle({"a.vyi": interface_code}) - - not_implemented_code = """ + main = """ import a as FooBarInterface implements: FooBarInterface -# a should be indexed -event Foo: - a: uint256 - @external def bar() -> uint256: + log FooBarInterface.Foo(1) return 1 """ - assert_compile_failed( - lambda: compile_code(not_implemented_code, input_bundle=input_bundle), InterfaceViolation - ) + assert compile_code(main, input_bundle=input_bundle) is not None VALID_IMPORT_CODE = [ diff --git a/tests/functional/examples/tokens/test_erc721.py b/tests/functional/examples/tokens/test_erc721.py index c881149baa..4f55807ed3 100644 --- a/tests/functional/examples/tokens/test_erc721.py +++ b/tests/functional/examples/tokens/test_erc721.py @@ -102,7 +102,7 @@ def test_transferFrom_by_owner(c, w3, tx_failed, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[0] + assert args.token_id == SOMEONE_TOKEN_IDS[0] assert c.ownerOf(SOMEONE_TOKEN_IDS[0]) == operator assert c.balanceOf(someone) == 2 assert c.balanceOf(operator) == 2 @@ -121,7 +121,7 @@ def test_transferFrom_by_approved(c, w3, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[1] + assert args.token_id == SOMEONE_TOKEN_IDS[1] assert c.ownerOf(SOMEONE_TOKEN_IDS[1]) == operator assert c.balanceOf(someone) == 2 assert c.balanceOf(operator) == 2 @@ -140,7 +140,7 @@ def test_transferFrom_by_operator(c, w3, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[2] + assert args.token_id == SOMEONE_TOKEN_IDS[2] assert c.ownerOf(SOMEONE_TOKEN_IDS[2]) == operator assert c.balanceOf(someone) == 2 assert c.balanceOf(operator) == 2 @@ -176,7 +176,7 @@ def test_safeTransferFrom_by_owner(c, w3, tx_failed, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[0] + assert args.token_id == SOMEONE_TOKEN_IDS[0] assert c.ownerOf(SOMEONE_TOKEN_IDS[0]) == operator assert c.balanceOf(someone) == 2 assert c.balanceOf(operator) == 2 @@ -197,7 +197,7 @@ def test_safeTransferFrom_by_approved(c, w3, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[1] + assert args.token_id == SOMEONE_TOKEN_IDS[1] assert c.ownerOf(SOMEONE_TOKEN_IDS[1]) == operator assert c.balanceOf(someone) == 2 assert c.balanceOf(operator) == 2 @@ -218,7 +218,7 @@ def test_safeTransferFrom_by_operator(c, w3, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[2] + assert args.token_id == SOMEONE_TOKEN_IDS[2] assert c.ownerOf(SOMEONE_TOKEN_IDS[2]) == operator assert c.balanceOf(someone) == 2 assert c.balanceOf(operator) == 2 @@ -254,7 +254,7 @@ def onERC721Received( args = logs[0].args assert args.sender == someone assert args.receiver == receiver.address - assert args.tokenId == SOMEONE_TOKEN_IDS[0] + assert args.token_id == SOMEONE_TOKEN_IDS[0] assert c.ownerOf(SOMEONE_TOKEN_IDS[0]) == receiver.address assert c.balanceOf(someone) == 2 assert c.balanceOf(receiver.address) == 1 @@ -282,7 +282,7 @@ def test_approve(c, w3, tx_failed, get_logs): args = logs[0].args assert args.owner == someone assert args.approved == operator - assert args.tokenId == SOMEONE_TOKEN_IDS[0] + assert args.token_id == SOMEONE_TOKEN_IDS[0] def test_setApprovalForAll(c, w3, tx_failed, get_logs): @@ -322,7 +322,7 @@ def test_mint(c, w3, tx_failed, get_logs): args = logs[0].args assert args.sender == ZERO_ADDRESS assert args.receiver == someone - assert args.tokenId == NEW_TOKEN_ID + assert args.token_id == NEW_TOKEN_ID assert c.ownerOf(NEW_TOKEN_ID) == someone assert c.balanceOf(someone) == 4 @@ -342,7 +342,7 @@ def test_burn(c, w3, tx_failed, get_logs): args = logs[0].args assert args.sender == someone assert args.receiver == ZERO_ADDRESS - assert args.tokenId == SOMEONE_TOKEN_IDS[0] + assert args.token_id == SOMEONE_TOKEN_IDS[0] with tx_failed(): c.ownerOf(SOMEONE_TOKEN_IDS[0]) assert c.balanceOf(someone) == 2 diff --git a/tests/functional/syntax/test_interfaces.py b/tests/functional/syntax/test_interfaces.py index 193e665a34..9cff0b156a 100644 --- a/tests/functional/syntax/test_interfaces.py +++ b/tests/functional/syntax/test_interfaces.py @@ -6,6 +6,7 @@ InterfaceViolation, InvalidReference, InvalidType, + NamespaceCollision, StructureException, SyntaxException, TypeMismatch, @@ -135,7 +136,7 @@ def f(a: uint256): # visibility is nonpayable instead of view InterfaceViolation, ), ( - # `receiver` of `Transfer` event should be indexed + # exports two Transfer events """ from ethereum.ercs import IERC20 @@ -146,11 +147,6 @@ def f(a: uint256): # visibility is nonpayable instead of view receiver: address value: uint256 -event Approval: - owner: indexed(address) - spender: indexed(address) - value: uint256 - name: public(String[32]) symbol: public(String[32]) decimals: public(uint8) @@ -160,55 +156,19 @@ def f(a: uint256): # visibility is nonpayable instead of view @external def transfer(_to : address, _value : uint256) -> bool: + log Transfer(msg.sender, _to, _value) return True @external def transferFrom(_from : address, _to : address, _value : uint256) -> bool: + log IERC20.Transfer(_from, _to, _value) return True @external def approve(_spender : address, _value : uint256) -> bool: return True """, - InterfaceViolation, - ), - ( - # `value` of `Transfer` event should not be indexed - """ -from ethereum.ercs import IERC20 - -implements: IERC20 - -event Transfer: - sender: indexed(address) - receiver: indexed(address) - value: indexed(uint256) - -event Approval: - owner: indexed(address) - spender: indexed(address) - value: uint256 - -name: public(String[32]) -symbol: public(String[32]) -decimals: public(uint8) -balanceOf: public(HashMap[address, uint256]) -allowance: public(HashMap[address, HashMap[address, uint256]]) -totalSupply: public(uint256) - -@external -def transfer(_to : address, _value : uint256) -> bool: - return True - -@external -def transferFrom(_from : address, _to : address, _value : uint256) -> bool: - return True - -@external -def approve(_spender : address, _value : uint256) -> bool: - return True - """, - InterfaceViolation, + NamespaceCollision, ), ( # `payable` decorator not implemented diff --git a/tests/functional/syntax/test_logging.py b/tests/functional/syntax/test_logging.py index edc728bd89..b96700a128 100644 --- a/tests/functional/syntax/test_logging.py +++ b/tests/functional/syntax/test_logging.py @@ -24,6 +24,14 @@ def foo(): log Bar(x) """, """ +struct Foo: + pass + +@external +def foo(): + log Foo # missing parens + """, + """ event Test: n: uint256 @@ -36,7 +44,7 @@ def test(): @pytest.mark.parametrize("bad_code", fail_list) def test_logging_fail(bad_code): - with pytest.raises(TypeMismatch): + with pytest.raises((TypeMismatch, StructureException)): compiler.compile_code(bad_code) diff --git a/tests/unit/compiler/test_abi.py b/tests/unit/compiler/test_abi.py new file mode 100644 index 0000000000..5ffb3f4616 --- /dev/null +++ b/tests/unit/compiler/test_abi.py @@ -0,0 +1,638 @@ +import pytest + +from vyper.compiler import compile_code +from vyper.compiler.output import build_abi_output +from vyper.compiler.phases import CompilerData + +source_codes = [ + """ +x: int128 + +@deploy +def __init__(): + self.x = 1 + """, + """ +x: int128 + +@deploy +def __init__(): + pass + """, +] + + +@pytest.mark.parametrize("source_code", source_codes) +def test_only_init_function(source_code): + empty_sig = [ + {"outputs": [], "inputs": [], "stateMutability": "nonpayable", "type": "constructor"} + ] + + data = CompilerData(source_code) + assert build_abi_output(data) == empty_sig + + +def test_default_abi(): + default_code = """ +@payable +@external +def __default__(): + pass + """ + + data = CompilerData(default_code) + assert build_abi_output(data) == [{"stateMutability": "payable", "type": "fallback"}] + + +def test_method_identifiers(): + code = """ +x: public(int128) + +@external +def foo(y: uint256) -> Bytes[100]: + return b"hello" + """ + + out = compile_code(code, output_formats=["method_identifiers"]) + + assert out["method_identifiers"] == {"foo(uint256)": "0x2fbebd38", "x()": "0xc55699c"} + + +def test_struct_abi(): + code = """ +struct MyStruct: + a: address + b: uint256 + +@external +@view +def foo(s: MyStruct) -> MyStruct: + return s + """ + + data = CompilerData(code) + abi = build_abi_output(data) + func_abi = abi[0] + + assert func_abi["name"] == "foo" + + expected_output = [ + { + "type": "tuple", + "name": "", + "components": [{"type": "address", "name": "a"}, {"type": "uint256", "name": "b"}], + } + ] + + assert func_abi["outputs"] == expected_output + + expected_input = { + "type": "tuple", + "name": "s", + "components": [{"type": "address", "name": "a"}, {"type": "uint256", "name": "b"}], + } + + assert func_abi["inputs"][0] == expected_input + + +@pytest.mark.parametrize( + "type,abi_type", [("DynArray[NestedStruct, 2]", "tuple[]"), ("NestedStruct[2]", "tuple[2]")] +) +def test_nested_struct(type, abi_type): + code = f""" +struct MyStruct: + a: address + b: bytes32 + +struct NestedStruct: + t: MyStruct + foo: uint256 + +@view +@external +def getStructList() -> {type}: + return [ + NestedStruct(t=MyStruct(a=msg.sender, b=block.prevhash), foo=1), + NestedStruct(t=MyStruct(a=msg.sender, b=block.prevhash), foo=2) + ] + """ + + out = compile_code(code, output_formats=["abi"]) + + assert out["abi"] == [ + { + "inputs": [], + "name": "getStructList", + "outputs": [ + { + "components": [ + { + "components": [ + {"name": "a", "type": "address"}, + {"name": "b", "type": "bytes32"}, + ], + "name": "t", + "type": "tuple", + }, + {"name": "foo", "type": "uint256"}, + ], + "name": "", + "type": f"{abi_type}", + } + ], + "stateMutability": "view", + "type": "function", + } + ] + + +@pytest.mark.parametrize( + "type,abi_type", [("DynArray[DynArray[Foo, 2], 2]", "tuple[][]"), ("Foo[2][2]", "tuple[2][2]")] +) +def test_2d_list_of_struct(type, abi_type): + code = f""" +struct Foo: + a: uint256 + b: uint256 + +@view +@external +def bar(x: {type}): + pass + """ + + out = compile_code(code, output_formats=["abi"]) + + assert out["abi"] == [ + { + "inputs": [ + { + "components": [ + {"name": "a", "type": "uint256"}, + {"name": "b", "type": "uint256"}, + ], + "name": "x", + "type": f"{abi_type}", + } + ], + "name": "bar", + "outputs": [], + "stateMutability": "view", + "type": "function", + } + ] + + +def test_exports_abi(make_input_bundle): + lib1 = """ +@external +def foo(): + pass + +@external +def bar(): + pass + """ + + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + + # just for clarity -- check bar() is not in the output + for fn in out["abi"]: + assert fn["name"] != "bar" + + expected = [ + { + "inputs": [], + "name": "foo", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + } + ] + + assert out["abi"] == expected + + +def test_exports_variable(make_input_bundle): + lib1 = """ +@external +def foo(): + pass + +private_storage_variable: uint256 +private_immutable_variable: immutable(uint256) +private_constant_variable: constant(uint256) = 3 + +public_storage_variable: public(uint256) +public_immutable_variable: public(immutable(uint256)) +public_constant_variable: public(constant(uint256)) = 10 + +@deploy +def __init__(a: uint256, b: uint256): + public_immutable_variable = a + private_immutable_variable = b + """ + + main = """ +import lib1 + +initializes: lib1 + +exports: ( + lib1.foo, + lib1.public_storage_variable, + lib1.public_immutable_variable, + lib1.public_constant_variable, +) + +@deploy +def __init__(): + lib1.__init__(5, 6) + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = [ + { + "inputs": [], + "name": "foo", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + { + "inputs": [], + "name": "public_storage_variable", + "outputs": [{"name": "", "type": "uint256"}], + "stateMutability": "view", + "type": "function", + }, + { + "inputs": [], + "name": "public_immutable_variable", + "outputs": [{"name": "", "type": "uint256"}], + "stateMutability": "view", + "type": "function", + }, + { + "inputs": [], + "name": "public_constant_variable", + "outputs": [{"name": "", "type": "uint256"}], + "stateMutability": "view", + "type": "function", + }, + {"inputs": [], "outputs": [], "stateMutability": "nonpayable", "type": "constructor"}, + ] + + assert out["abi"] == expected + + +def test_event_export_from_init(make_input_bundle): + # test that events get exported when used in init functions + lib1 = """ +event MyEvent: + pass + +@deploy +def __init__(): + log MyEvent() + """ + main = """ +import lib1 + +initializes: lib1 + +@deploy +def __init__(): + lib1.__init__() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + {"anonymous": False, "inputs": [], "name": "MyEvent", "type": "event"}, + {"inputs": [], "outputs": [], "stateMutability": "nonpayable", "type": "constructor"}, + ] + } + + assert out == expected + + +def test_event_export_from_function_export(make_input_bundle): + # test events used in exported functions are exported + lib1 = """ +event MyEvent: + pass + +@external +def foo(): + log MyEvent() + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + {"anonymous": False, "inputs": [], "name": "MyEvent", "type": "event"}, + { + "name": "foo", + "inputs": [], + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ] + } + + assert out == expected + + +def test_event_export_unused_function(make_input_bundle): + # test events in unused functions are not exported + lib1 = """ +event MyEvent: + pass + +@internal +def foo(): + log MyEvent() + """ + main = """ +import lib1 +initializes: lib1 + +# not exported/reachable from selector table +@internal +def do_foo(): + lib1.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = {"abi": []} + + assert out == expected + + +def test_event_export_unused_module(make_input_bundle): + # test events are exported from functions which are used, even + # if the module is not marked `uses:`. + lib1 = """ +event MyEvent: + pass + +@internal +def foo(): + log MyEvent() + """ + main = """ +import lib1 + +@external +def bar(): + lib1.foo() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + {"anonymous": False, "inputs": [], "name": "MyEvent", "type": "event"}, + { + "inputs": [], + "name": "bar", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ] + } + + assert out == expected + + +def test_event_no_export_implements(make_input_bundle): + # test events are not exported even if they are in implemented interface + ifoo = """ +event MyEvent: + pass + """ + main = """ +import ifoo + +implements: ifoo + """ + input_bundle = make_input_bundle({"ifoo.vyi": ifoo}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = {"abi": []} + + assert out == expected + + +def test_event_export_interface(make_input_bundle): + # test events from interfaces get exported + ifoo = """ +event MyEvent: + pass + +@external +def foo(): + ... + """ + main = """ +import ifoo + +@external +def bar(): + log ifoo.MyEvent() + """ + input_bundle = make_input_bundle({"ifoo.vyi": ifoo}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + {"anonymous": False, "inputs": [], "name": "MyEvent", "type": "event"}, + { + "inputs": [], + "name": "bar", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ] + } + assert out == expected + + +def test_event_export_interface_no_use(make_input_bundle): + # test events from interfaces don't get exported unless used + ifoo = """ +event MyEvent: + pass + +@external +def foo(): + ... + """ + main = """ +import ifoo + +@external +def bar(): + ifoo(msg.sender).foo() + """ + input_bundle = make_input_bundle({"ifoo.vyi": ifoo}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + { + "inputs": [], + "name": "bar", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + } + ] + } + + assert out == expected + + +def test_event_export_nested_export_chain(make_input_bundle): + # test exporting an event from a nested used module + lib1 = """ +event MyEvent: + pass + +@external +def foo(): + log MyEvent() + """ + lib2 = """ +import lib1 +exports: lib1.foo + """ + main = """ +import lib2 +exports: lib2.lib1.foo + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + {"anonymous": False, "inputs": [], "name": "MyEvent", "type": "event"}, + { + "name": "foo", + "inputs": [], + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ] + } + + assert out == expected + + +def test_event_export_nested_internal(make_input_bundle): + # test events are exported from nested internal calls across modules + lib1 = """ +event MyEvent: + pass + +@internal +def foo(): + log MyEvent() + """ + lib2 = """ +import lib1 + +@internal +def bar(): + lib1.foo() + """ + main = """ +import lib2 # no uses + +@external +def baz(): + lib2.bar() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + {"anonymous": False, "inputs": [], "name": "MyEvent", "type": "event"}, + { + "name": "baz", + "inputs": [], + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ] + } + + assert out == expected + + +def test_event_export_nested_no_uses(make_input_bundle): + # event is not exported when it's not used + lib1 = """ +event MyEvent: + pass + +counter: uint256 + +@internal +def foo(): + log MyEvent() + +@internal +def update_counter(): + self.counter += 1 + """ + lib2 = """ +import lib1 +uses: lib1 + +@internal +def use_lib1(): + lib1.update_counter() + """ + main = """ +import lib1 +import lib2 + +initializes: lib1 +initializes: lib2[lib1 := lib1] + +@external +def foo(): + lib2.use_lib1() + """ + input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) + out = compile_code(main, input_bundle=input_bundle, output_formats=["abi"]) + expected = { + "abi": [ + { + "name": "foo", + "inputs": [], + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + } + ] + } + + assert out == expected diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 2c18fa7ed9..772a9c0d03 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -167,7 +167,7 @@ pass_stmt: _PASS break_stmt: _BREAK continue_stmt: _CONTINUE -log_stmt: _LOG NAME "(" [arguments] ")" +log_stmt: _LOG (NAME | variable_access) "(" [arguments] ")" return_stmt: _RETURN [_expr ("," _expr)*] _UNREACHABLE: "UNREACHABLE" raise_stmt: _RAISE -> raise diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 053c2232b9..1310e997cf 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -728,6 +728,10 @@ def validate(self): class Log(Stmt): __slots__ = ("value",) + def validate(self): + if not isinstance(self.value, Call): + raise StructureException("Log must call an event", self.value) + class FlagDef(TopLevel): __slots__ = ("name", "body") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 4ddb89222e..cbc41a09a7 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -78,7 +78,7 @@ class arg(VyperNode): ... class Return(VyperNode): ... class Log(VyperNode): - value: VyperNode = ... + value: Call = ... class FlagDef(VyperNode): body: list = ... diff --git a/vyper/builtins/interfaces/IERC20.vyi b/vyper/builtins/interfaces/IERC20.vyi index ee533ab326..3f150d13e8 100644 --- a/vyper/builtins/interfaces/IERC20.vyi +++ b/vyper/builtins/interfaces/IERC20.vyi @@ -1,7 +1,7 @@ # Events event Transfer: sender: indexed(address) - recipient: indexed(address) + receiver: indexed(address) value: uint256 event Approval: diff --git a/vyper/builtins/interfaces/IERC721.vyi b/vyper/builtins/interfaces/IERC721.vyi index b8dcfd3c5f..345ba02529 100644 --- a/vyper/builtins/interfaces/IERC721.vyi +++ b/vyper/builtins/interfaces/IERC721.vyi @@ -2,7 +2,7 @@ event Transfer: sender: indexed(address) - recipient: indexed(address) + receiver: indexed(address) token_id: indexed(uint256) event Approval: diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 7a13050cac..9a1395bb49 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -22,9 +22,7 @@ def _runtime_reachable_functions(module_t, id_generator): ret = OrderedSet() for fn_t in module_t.exposed_functions: - # resolve variabledecl getter source - if isinstance(fn_t.ast_def, vy_ast.VariableDecl): - fn_t = fn_t.ast_def._expanded_getter._metadata["func_type"] + assert isinstance(fn_t.ast_def, vy_ast.FunctionDef) ret.update(fn_t.reachable_internal_functions) ret.add(fn_t) @@ -512,12 +510,18 @@ def generate_ir_for_module(module_t: ModuleT) -> tuple[IRnode, IRnode]: raise CompilerPanic("unreachable") deploy_code.append(["deploy", 0, runtime, 0]) - # compile all internal functions so that _ir_info is populated (whether or - # not it makes it into the final IR artifact) + # compile all remaining internal functions so that _ir_info is populated + # (whether or not it makes it into the final IR artifact) + to_visit: OrderedSet = OrderedSet() for func_ast in module_t.function_defs: fn_t = func_ast._metadata["func_type"] - if fn_t.is_internal and fn_t._ir_info is None: + if fn_t.is_internal: + to_visit.update(fn_t.reachable_internal_functions) + to_visit.add(fn_t) + + for fn_t in to_visit: + if fn_t._ir_info is None: id_generator.ensure_id(fn_t) - _ = _ir_for_internal_function(func_ast, module_t, False) + _ = _ir_for_internal_function(fn_t.ast_def, module_t, False) return IRnode.from_list(deploy_code), IRnode.from_list(runtime) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 7919d5c427..def62576e0 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -175,9 +175,9 @@ def __post_init__(self): self._modification_count = 0 @property - def getter_type(self) -> Optional["ContractFunctionT"]: + def getter_ast(self) -> Optional[vy_ast.VyperNode]: assert self.decl_node is not None # help mypy - ret = self.decl_node._metadata.get("getter_type", None) + ret = self.decl_node._expanded_getter assert (ret is not None) == self.is_public, self return ret diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 449f4e05e1..7bcef3506c 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -563,8 +563,9 @@ def visit_If(self, node): self.visit(n) def visit_Log(self, node): - if not isinstance(node.value, vy_ast.Call): - raise StructureException("Log must call an event", node) + # postcondition of Log.validate() + assert isinstance(node.value, vy_ast.Call) + f = get_exact_type_from_node(node.value.func) if not is_type_t(f, EventT): raise StructureException("Value is not an event", node.value) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 5c61864aa5..3f2fd2cebc 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -67,6 +67,7 @@ def validate_module_semantics_r( assert isinstance(module_ast._metadata["type"], ModuleT) return module_ast._metadata["type"] + # TODO: move this to parser or VyperNode construction validate_literal_nodes(module_ast) # validate semantics and annotate AST with type/semantics information @@ -74,10 +75,14 @@ def validate_module_semantics_r( with namespace.enter_scope(), import_graph.enter_path(module_ast): analyzer = ModuleAnalyzer(module_ast, input_bundle, namespace, import_graph, is_interface) - ret = analyzer.analyze() + analyzer.analyze_module_body() + _analyze_call_graph(module_ast) generate_public_variable_getters(module_ast) + ret = ModuleT(module_ast) + module_ast._metadata["type"] = ret + # if this is an interface, the function is already validated # in `ContractFunction.from_vyi()` if not is_interface: @@ -88,6 +93,39 @@ def validate_module_semantics_r( return ret +def _analyze_call_graph(module_ast: vy_ast.Module): + # get list of internal function calls made by each function + # CMC 2024-02-03 note: this could be cleaner in analysis/local.py + function_defs = module_ast.get_children(vy_ast.FunctionDef) + + for func in function_defs: + fn_t = func._metadata["func_type"] + assert len(fn_t.called_functions) == 0 + fn_t.called_functions = OrderedSet() + + function_calls = func.get_descendants(vy_ast.Call) + + for call in function_calls: + try: + call_t = get_exact_type_from_node(call.func) + except VyperException: + # there is a problem getting the call type. this might be + # an issue, but it will be handled properly later. right now + # we just want to be able to construct the call graph. + continue + + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): + fn_t.called_functions.add(call_t) + + for func in function_defs: + fn_t = func._metadata["func_type"] + + # compute reachable set and validate the call graph + _compute_reachable_set(fn_t) + + # compute reachable set and validate the call graph (detect cycles) def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT] = None) -> None: path = path or [] @@ -96,16 +134,19 @@ def _compute_reachable_set(fn_t: ContractFunctionT, path: list[ContractFunctionT root = path[0] for g in fn_t.called_functions: + if g in fn_t.reachable_internal_functions: + # already seen + continue + if g == root: message = " -> ".join([f.name for f in path]) raise CallViolation(f"Contract contains cyclic function call: {message}") _compute_reachable_set(g, path=path) - for h in g.reachable_internal_functions: - assert h != fn_t # sanity check - - fn_t.reachable_internal_functions.add(h) + g_reachable = g.reachable_internal_functions + assert fn_t not in g_reachable # sanity check + fn_t.reachable_internal_functions.update(g_reachable) fn_t.reachable_internal_functions.add(g) @@ -143,7 +184,7 @@ def __init__( if not hasattr(self.input_bundle._cache, "_ast_of"): self.input_bundle._cache._ast_of: dict[int, vy_ast.Module] = {} # type: ignore - def analyze(self) -> ModuleT: + def analyze_module_body(self): # generate a `ModuleT` from the top-level node # note: also validates unique method ids @@ -169,13 +210,9 @@ def analyze(self) -> ModuleT: # run before exports for exception handling priority self._visit_nodes_looping((vy_ast.VariableDecl, vy_ast.FunctionDef)) - # mutate _exposed_functions + # mutates _exposed_functions self._visit_nodes_linear(vy_ast.ExportsDecl) - # we can get a ModuleT once all functions and types are handled - self.module_t = ModuleT(self.ast) - self.ast._metadata["type"] = self.module_t - # handle implements last, after all functions are handled self._visit_nodes_linear(vy_ast.ImplementsDecl) @@ -191,40 +228,6 @@ def analyze(self) -> ModuleT: _ns.update({k: self.namespace[k] for k in self.namespace._scopes[-1]}) # type: ignore self.ast._metadata["namespace"] = _ns - self.analyze_call_graph() - - return self.module_t - - def analyze_call_graph(self): - # get list of internal function calls made by each function - # CMC 2024-02-03 note: this could be cleaner in analysis/local.py - function_defs = self.module_t.function_defs - - for func in function_defs: - fn_t = func._metadata["func_type"] - - function_calls = func.get_descendants(vy_ast.Call) - - for call in function_calls: - try: - call_t = get_exact_type_from_node(call.func) - except VyperException: - # either there is a problem getting the call type. this is - # an issue, but it will be handled properly later. right now - # we just want to be able to construct the call graph. - continue - - if isinstance(call_t, ContractFunctionT) and ( - call_t.is_internal or call_t.is_constructor - ): - fn_t.called_functions.add(call_t) - - for func in function_defs: - fn_t = func._metadata["func_type"] - - # compute reachable set and validate the call graph - _compute_reachable_set(fn_t) - def _visit_nodes_linear(self, node_type): for node in self._to_visit.copy(): if not isinstance(node, node_type): @@ -389,9 +392,11 @@ def visit_ImplementsDecl(self, node): hint = f"try renaming `{path}` to `{path}i`" raise StructureException(msg, node.annotation, hint=hint) - funcs = {fn_t: fn_t.decl_node for fn_t in self.module_t.exposed_functions} - events = [n._metadata["event_type"] for n in self.module_t.event_defs] - type_.validate_implements(node, funcs, events) + # grab exposed functions + funcs = self._exposed_functions + type_.validate_implements(node, funcs) + + node._metadata["interface_type"] = type_ def visit_UsesDecl(self, node): # TODO: check duplicate uses declarations, e.g. @@ -508,8 +513,8 @@ def visit_ExportsDecl(self, node): if not func_t.is_external: raise StructureException("not an external function!", decl_node, item) - self._add_exposed_function(func_t, item) - with tag_exceptions(item): + self._add_exposed_function(func_t, item, relax=False) + with tag_exceptions(item): # tag with specific item self._self_t.typ.add_member(func_t.name, func_t) funcs.append(func_t) @@ -527,7 +532,7 @@ def visit_ExportsDecl(self, node): def _self_t(self): return self.namespace["self"] - def _add_exposed_function(self, func_t, node): + def _add_exposed_function(self, func_t, node, relax=True): # call this before self._self_t.typ.add_member() for exception raising # priority if (prev_decl := self._exposed_functions.get(func_t)) is not None: @@ -545,6 +550,7 @@ def visit_VariableDecl(self, node): # we need this when building the public getter func_t = ContractFunctionT.getter_from_VariableDecl(node) node._metadata["getter_type"] = func_t + self._add_exposed_function(func_t, node) # TODO: move this check to local analysis if node.is_immutable: @@ -657,6 +663,7 @@ def visit_FunctionDef(self, node): self._self_t.typ.add_member(func_t.name, func_t) node._metadata["func_type"] = func_t + self._add_exposed_function(func_t, node) def visit_Import(self, node): # import x.y[name] as y[alias] @@ -721,6 +728,12 @@ def _load_import_helper( hint = "try renaming `vyper.interfaces` to `ethereum.ercs`" raise ModuleNotFound(module_str, hint=hint) if _is_builtin(module_str): + components = module_str.split(".") + # hint: rename ERC20 to IERC20 + if components[-1].startswith("ERC"): + module_prefix = components[-1] + hint = f"try renaming `{module_prefix}` to `I{module_prefix}`" + raise ModuleNotFound(module_str, hint=hint) return _load_builtin_import(level, module_str) path = _import_to_path(level, module_str) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 0759c7aa84..2cbb972ac7 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -115,10 +115,11 @@ def __init__( self._analysed = False # a list of internal functions this function calls. - # to be populated during analysis + # to be populated during module analysis. self.called_functions: OrderedSet[ContractFunctionT] = OrderedSet() # recursively reachable from this function + # to be populated during module analysis. self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() # writes to variables from this function @@ -475,9 +476,8 @@ def implements(self, other: "ContractFunctionT") -> bool: Used when determining if an interface has been implemented. This method should not be directly implemented by any inherited classes. """ - - if not self.is_external: - return False + if not self.is_external: # pragma: nocover + raise CompilerPanic("unreachable!") arguments, return_type = self._iface_sig other_arguments, other_return_type = other._iface_sig diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index ceeacd7263..5faefaf404 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -13,14 +13,16 @@ from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.utils import ( check_modifiability, + get_exact_type_from_node, validate_expected_type, validate_unique_method_ids, ) from vyper.semantics.data_locations import DataLocation -from vyper.semantics.types.base import TYPE_T, VyperType +from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.primitives import AddressT from vyper.semantics.types.user import EventT, StructT, _UserType +from vyper.utils import OrderedSet if TYPE_CHECKING: from vyper.semantics.analysis.base import ModuleInfo @@ -92,13 +94,9 @@ def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifia return check_modifiability(node.args[0], modifiability) def validate_implements( - self, - node: vy_ast.ImplementsDecl, - functions: dict[ContractFunctionT, vy_ast.VyperNode], - events: list[EventT], + self, node: vy_ast.ImplementsDecl, functions: dict[ContractFunctionT, vy_ast.VyperNode] ) -> None: fns_by_name = {fn_t.name: fn_t for fn_t in functions.keys()} - events_by_name = {event_t.name: event_t for event_t in events} unimplemented = [] @@ -120,25 +118,13 @@ def _is_function_implemented(fn_name, fn_type): if not _is_function_implemented(name, type_): unimplemented.append(name) - # check for missing events - for name, event in self.events.items(): - if name not in events_by_name: - unimplemented.append(name) - continue - - other = events_by_name[name] - - if other.event_id != event.event_id or other.indexed != event.indexed: - unimplemented.append(f"{name} is not implemented! (should be {event})") - if len(unimplemented) > 0: # TODO: improve the error message for cases where the # mismatch is small (like mutability, or just one argument # is off, etc). missing_str = ", ".join(sorted(unimplemented)) raise InterfaceViolation( - f"Contract does not implement all interface functions or events: {missing_str}", - node, + f"Contract does not implement all interface functions: {missing_str}", node ) def to_toplevel_abi_dict(self) -> list[dict]: @@ -235,8 +221,13 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT": if (fn_t := module_t.init_function) is not None: funcs.append((fn_t.name, fn_t)) - events = [(node.name, node._metadata["event_type"]) for node in module_t.event_defs] + event_set: OrderedSet[EventT] = OrderedSet() + event_set.update([node._metadata["event_type"] for node in module_t.event_defs]) + event_set.update(module_t.used_events) + events = [(event_t.name, event_t) for event_t in event_set] + # these are accessible via import, but they do not show up + # in the ABI json structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs] return cls._from_lists(module_t._id, funcs, events, structs) @@ -330,12 +321,11 @@ def __hash__(self): def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": return self._helper.get_member(key, node) - # this is a property, because the function set changes after AST expansion - @property + @cached_property def function_defs(self): return self._module.get_children(vy_ast.FunctionDef) - @property + @cached_property def event_defs(self): return self._module.get_children(vy_ast.EventDef) @@ -347,6 +337,10 @@ def struct_defs(self): def interface_defs(self): return self._module.get_children(vy_ast.InterfaceDef) + @cached_property + def implements_decls(self): + return self._module.get_children(vy_ast.ImplementsDecl) + @cached_property def interfaces(self) -> dict[str, InterfaceT]: ret = {} @@ -425,7 +419,7 @@ def exposed_functions(self): ret.extend(node._metadata["exports_info"].functions) ret.extend([f for f in self.functions.values() if f.is_external]) - ret.extend([v.getter_type for v in self.public_variables.values()]) + ret.extend([v.getter_ast._metadata["func_type"] for v in self.public_variables.values()]) # precondition: no duplicate exports assert len(set(ret)) == len(ret) @@ -450,6 +444,35 @@ def public_variables(self): def functions(self): return {f.name: f._metadata["func_type"] for f in self.function_defs} + @cached_property + # it would be nice to rely on the function analyzer to do this analysis, + # but we don't have the result of function analysis at the time we need to + # construct `self.interface`. + def used_events(self) -> OrderedSet[EventT]: + ret: OrderedSet[EventT] = OrderedSet() + + reachable: OrderedSet[ContractFunctionT] = OrderedSet() + if self.init_function is not None: + reachable.add(self.init_function) + reachable.update(self.init_function.reachable_internal_functions) + for fn_t in self.exposed_functions: + reachable.add(fn_t) + reachable.update(fn_t.reachable_internal_functions) + + for fn_t in reachable: + fn_ast = fn_t.decl_node + assert isinstance(fn_ast, vy_ast.FunctionDef) + + for node in fn_ast.get_descendants(vy_ast.Log): + call_t = get_exact_type_from_node(node.value.func) + if not is_type_t(call_t, EventT): + # this is an error, but it will be handled later + continue + + ret.add(call_t.typedef) + + return ret + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable]