diff --git a/boa/contracts/abi/abi_contract.py b/boa/contracts/abi/abi_contract.py index 9c4aba83..06a142a9 100644 --- a/boa/contracts/abi/abi_contract.py +++ b/boa/contracts/abi/abi_contract.py @@ -1,7 +1,6 @@ from collections import defaultdict from copy import deepcopy from functools import cached_property -from os.path import basename from typing import Any, Optional, Union from warnings import warn @@ -33,7 +32,9 @@ def __init__(self, abi: dict, contract_name: str): self.contract: Optional["ABIContract"] = None @property - def name(self) -> str: + def name(self) -> str | None: + if self.is_constructor: + return None return self._abi["name"] @cached_property @@ -54,16 +55,28 @@ def return_type(self) -> list: @property def full_signature(self) -> str: + assert self.name is not None, "Constructor does not have a name." return f"{self.name}{self.signature}" @property def pretty_signature(self) -> str: - return f"{self.name}{self.signature} -> {self.return_type}" + return f"{self.pretty_name}{self.signature} -> {self.return_type}" + + @cached_property + def pretty_name(self): + if self.is_constructor: + return "constructor" + return self.name @cached_property def method_id(self) -> bytes: + assert self.name, "Constructor does not have a method id." return method_id(self.name + self.signature) + @cached_property + def is_constructor(self): + return self._abi["type"] == "constructor" + def __repr__(self) -> str: return f"ABI {self._contract_name}.{self.pretty_signature}" @@ -87,7 +100,10 @@ def is_encodable(self, *args, **kwargs) -> bool: def prepare_calldata(self, *args, **kwargs) -> bytes: """Prepare the call data for the function call.""" abi_args = self._merge_kwargs(*args, **kwargs) - return self.method_id + abi_encode(self.signature, abi_args) + encoded_args = abi_encode(self.signature, abi_args) + if self.is_constructor: + return encoded_args + return self.method_id + encoded_args def _merge_kwargs(self, *args, **kwargs) -> list: """Merge positional and keyword arguments into a single list.""" @@ -153,7 +169,7 @@ def __init__(self, functions: list[ABIFunction]): self.functions = functions @cached_property - def name(self) -> str: + def name(self) -> str | None: return self.functions[0].name def prepare_calldata(self, *args, disambiguate_signature=None, **kwargs) -> bytes: @@ -193,6 +209,7 @@ def _pick_overload( ] assert len(matches) <= 1, "ABI signature must be unique" + assert self.name, "Constructor does not have a name." match matches: case [function]: return function @@ -215,7 +232,7 @@ class ABIContract(_BaseEVMContract): def __init__( self, name: str, - abi: dict, + abi: list[dict], functions: list[ABIFunction], address: Address, filename: Optional[str] = None, @@ -237,10 +254,12 @@ def __init__( for f in self._functions: overloads[f.name].append(f) - for name, group in overloads.items(): - setattr(self, name, ABIOverload.create(group, self)) + for fn_name, group in overloads.items(): + if fn_name is not None: # constructors have no name + setattr(self, fn_name, ABIOverload.create(group, self)) self._address = Address(address) + self._computation: Optional[ComputationAPI] = None @property def abi(self): @@ -252,7 +271,11 @@ def method_id_map(self): Returns a mapping from method id to function object. This is used to create the stack trace when an error occurs. """ - return {function.method_id: function for function in self._functions} + return { + function.method_id: function + for function in self._functions + if not function.is_constructor + } def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]: """ @@ -260,6 +283,7 @@ def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...] :param computation: the computation object returned by `execute_code` :param abi_type: the ABI type of the return value. """ + self._computation = computation # when there's no contract in the address, the computation output is empty if computation.is_error: return self.handle_error(computation) @@ -274,13 +298,17 @@ def stack_trace(self, computation: ComputationAPI) -> StackTrace: """ Create a stack trace for a failed contract call. """ + reason = "" + if computation.is_error: + reason = " ".join(str(arg) for arg in computation.error.args if arg != b"") + calldata_method_id = bytes(computation.msg.data[:4]) if calldata_method_id in self.method_id_map: function = self.method_id_map[calldata_method_id] - msg = f" ({self}.{function.pretty_signature})" + msg = f" {reason}({self}.{function.pretty_signature})" else: # Method might not be specified in the ABI - msg = f" (unknown method id {self}.0x{calldata_method_id.hex()})" + msg = f" {reason}(unknown method id {self}.0x{calldata_method_id.hex()})" return_trace = StackTrace([msg]) return _handle_child_trace(computation, self.env, return_trace) @@ -290,7 +318,7 @@ def deployer(self) -> "ABIContractFactory": """ Returns a factory that can be used to retrieve another deployed contract. """ - return ABIContractFactory(self._name, self._abi, self._functions) + return ABIContractFactory(self._name, self._abi, filename=self.filename) def __repr__(self): file_str = f" (file {self.filename})" if self.filename else "" @@ -305,28 +333,26 @@ class ABIContractFactory: do any contract deployment. """ - def __init__( - self, - name: str, - abi: dict, - functions: list[ABIFunction], - filename: Optional[str] = None, - ): + def __init__(self, name: str, abi: list[dict], filename: Optional[str] = None): self._name = name self._abi = abi - self._functions = functions - self._filename = filename + self.filename = filename @cached_property def abi(self): return deepcopy(self._abi) - @classmethod - def from_abi_dict(cls, abi, name=""): - functions = [ - ABIFunction(item, name) for item in abi if item.get("type") == "function" + @cached_property + def functions(self): + return [ + ABIFunction(item, self._name) + for item in self.abi + if item.get("type") == "function" ] - return cls(basename(name), abi, functions, filename=name) + + @classmethod + def from_abi_dict(cls, abi, name="", filename=None): + return cls(name, abi, filename) def at(self, address: Address | str) -> ABIContract: """ @@ -334,7 +360,7 @@ def at(self, address: Address | str) -> ABIContract: """ address = Address(address) contract = ABIContract( - self._name, self._abi, self._functions, address, self._filename + self._name, self._abi, self.functions, address, self.filename ) contract.env.register_contract(address, contract) return contract diff --git a/boa/contracts/base_evm_contract.py b/boa/contracts/base_evm_contract.py index d9e6d0cd..fdeb2898 100644 --- a/boa/contracts/base_evm_contract.py +++ b/boa/contracts/base_evm_contract.py @@ -61,10 +61,9 @@ def last_frame(self): def _trace_for_unknown_contract(computation, env): - ret = StackTrace( - [f""] - ) - return _handle_child_trace(computation, env, ret) + err = f" " + trace = StackTrace([err]) + return _handle_child_trace(computation, env, trace) def _handle_child_trace(computation, env, return_trace): diff --git a/boa/contracts/vyper/vyper_contract.py b/boa/contracts/vyper/vyper_contract.py index 98c7e503..24ed9bd9 100644 --- a/boa/contracts/vyper/vyper_contract.py +++ b/boa/contracts/vyper/vyper_contract.py @@ -74,6 +74,8 @@ class VyperDeployer: + create_compiler_data = CompilerData # this may be a different class in plugins + def __init__(self, compiler_data, filename=None): self.compiler_data = compiler_data @@ -303,6 +305,11 @@ def _check(cond, msg=""): assert len(args) == 1, "multiple args!" assert len(kwargs) == 0, "can't mix args and kwargs!" err = args[0] + if isinstance(frame, str): + # frame for unknown contracts is a string + _check(err in frame, f"{frame} does not match {args}") + return + # try to match anything _check( err == frame.pretty_vm_reason @@ -315,6 +322,10 @@ def _check(cond, msg=""): # try to match a specific kwarg assert len(kwargs) == 1 and len(args) == 0 + if isinstance(frame, str): + # frame for unknown contracts is a string + raise ValueError(f"expected {kwargs} but got {frame}") + # don't accept magic if frame.dev_reason: assert frame.dev_reason.reason_type not in ("vm_error", "compiler") @@ -547,6 +558,7 @@ def _set_bytecode(self, bytecode: bytes) -> None: to_check = bytecode if self.data_section_size != 0: to_check = bytecode[: -self.data_section_size] + assert isinstance(self.compiler_data, CompilerData) if to_check != self.compiler_data.bytecode_runtime: warnings.warn( f"casted bytecode does not match compiled bytecode at {self}", diff --git a/boa/integrations/jupyter/browser.py b/boa/integrations/jupyter/browser.py index 22426939..019df7fa 100644 --- a/boa/integrations/jupyter/browser.py +++ b/boa/integrations/jupyter/browser.py @@ -74,21 +74,16 @@ def send_transaction(self, tx_data: dict) -> dict: ) return convert_frontend_dict(sign_data) - def sign_typed_data( - self, domain: dict[str, Any], types: dict[str, list], value: dict[str, Any] - ) -> str: + def sign_typed_data(self, full_message: dict[str, Any]) -> str: """ Sign typed data value with types data structure for domain using the EIP-712 specification. - :param domain: The domain data structure. - :param types: The types data structure. - :param value: The value to sign. + :param full_message: The full message to sign. :return: The signature. """ return _javascript_call( - "signTypedData", - domain, - types, - value, + "rpc", + "eth_signTypedData_v4", + [self.address, full_message], timeout_message=TRANSACTION_TIMEOUT_MESSAGE, ) @@ -141,18 +136,10 @@ def __init__(self, address=None, **kwargs): self.signer = BrowserSigner(address) self.set_eoa(self.signer) - def get_chain_id(self) -> int: - chain_id = _javascript_call( - "rpc", "eth_chainId", timeout_message=RPC_TIMEOUT_MESSAGE - ) - return int.from_bytes(bytes.fromhex(chain_id[2:]), "big") - def set_chain_id(self, chain_id: int | str): - _javascript_call( - "rpc", + self._rpc.fetch( "wallet_switchEthereumChain", [{"chainId": chain_id if isinstance(chain_id, str) else hex(chain_id)}], - timeout_message=RPC_TIMEOUT_MESSAGE, ) self._reset_fork() @@ -169,7 +156,7 @@ def _javascript_call(js_func: str, *args, timeout_message: str) -> Any: :return: The result of the Javascript snippet sent to the API. """ token = _generate_token() - args_str = ", ".join(json.dumps(p) for p in chain([token], args)) + args_str = ", ".join(json.dumps(p, cls=_BytesEncoder) for p in chain([token], args)) js_code = f"window._titanoboa.{js_func}({args_str});" if BrowserRPC._debug_mode: logging.warning(f"Calling {js_func} with {args_str}") @@ -224,9 +211,31 @@ def _parse_js_result(result: dict) -> Any: if "data" in result: return result["data"] + def _find_key(input_dict, target_key, typ) -> Any: + for key, value in input_dict.items(): + if isinstance(value, dict): + found = _find_key(value, target_key, typ) + if found is not None: + return found + if key == target_key and isinstance(value, typ) and value != "error": + return value + return None + # raise the error in the Jupyter cell so that the user can see it error = result["error"] - error = error.get("info", error).get("error", error) + error = error.get("data", error) raise RPCError( - message=error.get("message", error), code=error.get("code", "CALLBACK_ERROR") + message=_find_key(error, "message", str) or _find_key(error, "error", str), + code=_find_key(error, "code", int) or -1, ) + + +class _BytesEncoder(json.JSONEncoder): + """ + A JSONEncoder that converts bytes to hex strings to be passed to JavaScript. + """ + + def default(self, o): + if isinstance(o, bytes): + return "0x" + o.hex() + return super().default(o) diff --git a/boa/integrations/jupyter/constants.py b/boa/integrations/jupyter/constants.py index ef39ac26..7cc199a2 100644 --- a/boa/integrations/jupyter/constants.py +++ b/boa/integrations/jupyter/constants.py @@ -2,7 +2,7 @@ NUL = b"\0" CALLBACK_TOKEN_TIMEOUT = timedelta(minutes=3) -SHARED_MEMORY_LENGTH = 50 * 1024 + len(NUL) # Size of the shared memory object +SHARED_MEMORY_LENGTH = 100 * 1024 + len(NUL) # Size of the shared memory object CALLBACK_TOKEN_CHARS = 30 # OSx limits this to 31 characters PLUGIN_NAME = "titanoboa_jupyterlab" TOKEN_REGEX = rf"[0-9a-fA-F]{{{CALLBACK_TOKEN_CHARS}}}" diff --git a/boa/integrations/jupyter/jupyter.js b/boa/integrations/jupyter/jupyter.js index e84d0763..a25f1bc1 100644 --- a/boa/integrations/jupyter/jupyter.js +++ b/boa/integrations/jupyter/jupyter.js @@ -44,22 +44,14 @@ return response.text(); } - let from; const loadSigner = async (address) => { const accounts = await rpc('eth_requestAccounts'); - from = accounts.includes(address) ? address : accounts[0]; - return from; + return accounts.includes(address) ? address : accounts[0]; }; /** Sign a transaction via ethers */ const sendTransaction = async transaction => ({"hash": await rpc('eth_sendTransaction', [transaction])}); - /** Sign a typed data via ethers */ - const signTypedData = (domain, types, value) => rpc( - 'eth_signTypedData_v4', - [from, JSON.stringify({domain, types, value})] - ); - /** Wait until the transaction is mined */ const waitForTransactionReceipt = async (tx_hash, timeout, poll_latency) => { while (true) { @@ -120,7 +112,6 @@ window._titanoboa = { loadSigner: handleCallback(loadSigner), sendTransaction: handleCallback(sendTransaction), - signTypedData: handleCallback(signTypedData), waitForTransactionReceipt: handleCallback(waitForTransactionReceipt), rpc: handleCallback(rpc), multiRpc: handleCallback(multiRpc), diff --git a/boa/interpret.py b/boa/interpret.py index 3fe5c1de..c40d1f66 100644 --- a/boa/interpret.py +++ b/boa/interpret.py @@ -18,6 +18,7 @@ VyperContract, VyperDeployer, ) +from boa.environment import Env from boa.explorer import fetch_abi_from_etherscan from boa.util.abi import Address from boa.util.disk_cache import DiskCache @@ -75,8 +76,12 @@ def create_module(self, spec): sys.meta_path.append(BoaImporter()) -def compiler_data(source_code: str, contract_name: str, **kwargs) -> CompilerData: +def compiler_data( + source_code: str, contract_name: str, deployer=None, **kwargs +) -> CompilerData: global _disk_cache + if deployer is None: + deployer = _get_default_deployer_class() def _ifaces(): # use get_interface_codes to get the interface source dict @@ -86,22 +91,27 @@ def _ifaces(): if _disk_cache is None: ifaces = _ifaces() - ret = CompilerData(source_code, contract_name, interface_codes=ifaces, **kwargs) - return ret + return deployer.create_compiler_data( + source_code, contract_name, interface_codes=ifaces, **kwargs + ) def func(): ifaces = _ifaces() - ret = CompilerData(source_code, contract_name, interface_codes=ifaces, **kwargs) + ret = deployer.create_compiler_data( + source_code, contract_name, interface_codes=ifaces, **kwargs + ) with anchor_compiler_settings(ret): _ = ret.bytecode, ret.bytecode_runtime # force compilation to happen return ret - cache_key = str((contract_name, source_code, kwargs)) + assert isinstance(deployer, type) + deployer_id = repr(deployer) # a unique str identifying the deployer class + cache_key = str((contract_name, source_code, kwargs, deployer_id)) return _disk_cache.caching_lookup(cache_key, func) def load(filename: str | Path, *args, **kwargs) -> _Contract: # type: ignore - name = filename + name = Path(filename).stem # TODO: investigate if we can just put name in the signature if "name" in kwargs: name = kwargs.pop("name") @@ -149,11 +159,12 @@ def loads_partial( compiler_args = compiler_args or {} - data = compiler_data(source_code, name, **compiler_args) - return VyperDeployer(data, filename=filename) + deployer_class = _get_default_deployer_class() + data = compiler_data(source_code, name, deployer_class, **compiler_args) + return deployer_class(data, filename=filename) -def load_partial(filename: str, compiler_args=None) -> VyperDeployer: # type: ignore +def load_partial(filename: str, compiler_args=None): with open(filename) as f: return loads_partial( f.read(), name=filename, filename=filename, compiler_args=compiler_args @@ -168,4 +179,11 @@ def from_etherscan( return ABIContractFactory.from_abi_dict(abi, name=name).at(addr) +def _get_default_deployer_class(): + env = Env.get_singleton() + if hasattr(env, "deployer_class"): + return env.deployer_class + return VyperDeployer + + __all__ = [] # type: ignore diff --git a/boa/network.py b/boa/network.py index b3115835..5e347e1d 100644 --- a/boa/network.py +++ b/boa/network.py @@ -43,8 +43,8 @@ def returndata_bytes(self): def is_error(self): if "structLogs" in self.raw_trace: return self.raw_trace["failed"] - else: - return "error" in self.raw_trace + # we can have `"error": null` in the payload + return self.raw_trace.get("error") is not None class _EstimateGasFailed(Exception): @@ -162,7 +162,7 @@ def __init__( ) rpc = EthereumRPC(rpc) - self._rpc = rpc + self._rpc: RPC = rpc self._reset_fork() @@ -489,7 +489,7 @@ def _send_txn(self, from_, to=None, gas=None, value=None, data=None): except RPCError as e: if e.code == 3: # execution failed at estimateGas, probably the txn reverted - raise _EstimateGasFailed() + raise _EstimateGasFailed() from e raise e from e if from_ not in self._accounts: @@ -523,6 +523,11 @@ def _send_txn(self, from_, to=None, gas=None, value=None, data=None): t_obj = TraceObject(trace) if trace is not None else None return receipt, t_obj + def get_chain_id(self) -> int: + """Get the current chain ID of the network as an integer.""" + chain_id = self._rpc.fetch("eth_chainId", []) + return int(chain_id, 16) + def set_balance(self, address, value): raise NotImplementedError("Cannot use set_balance in network mode") diff --git a/boa/vm/py_evm.py b/boa/vm/py_evm.py index f4a9392c..c9603d94 100644 --- a/boa/vm/py_evm.py +++ b/boa/vm/py_evm.py @@ -394,7 +394,7 @@ def fork_rpc(self, rpc: RPC, block_identifier: str, **kwargs): self.patch.timestamp = int(block_info["timestamp"], 16) self.patch.block_number = int(block_info["number"], 16) - # TODO patch the other stuff + self.patch.chain_id = int(rpc.fetch("eth_chainId", []), 16) self.vm.state._account_db._rpc._init_db() diff --git a/tests/unitary/jupyter/test_browser.py b/tests/unitary/jupyter/test_browser.py index db790276..2cb7e697 100644 --- a/tests/unitary/jupyter/test_browser.py +++ b/tests/unitary/jupyter/test_browser.py @@ -108,6 +108,7 @@ def create_task(future): def mock_fork(mock_callback): mock_callback("evm_snapshot", "0x123456") mock_callback("evm_revert", "0x12345678") + mock_callback("eth_chainId", "0x1") data = {"number": "0x123", "timestamp": "0x65bbb460"} mock_callback("eth_getBlockByNumber", data) @@ -135,9 +136,9 @@ def test_nest_applied(): def test_browser_sign_typed_data(display_mock, mock_callback, env): signature = env.generate_address() - mock_callback("signTypedData", signature) + mock_callback("eth_signTypedData_v4", signature) data = env.signer.sign_typed_data( - {"name": "My App"}, {"types": []}, {"data": "0x1234"} + full_message={"domain": {"name": "My App"}, "types": [], "data": "0x1234"} ) assert data == signature @@ -165,8 +166,8 @@ def test_browser_chain_id(token, env, display_mock, mock_callback): assert env.get_chain_id() == 4660 mock_callback("wallet_switchEthereumChain") env.set_chain_id(1) - assert display_mock.call_count == 3 - (js,), _ = display_mock.call_args_list[-2] + assert display_mock.call_count == 4 + (js,), _ = display_mock.call_args_list[1] assert ( f'rpc("{token}", "wallet_switchEthereumChain", [{{"chainId": "0x1"}}])' in js.data @@ -177,7 +178,7 @@ def test_browser_rpc(token, display_mock, mock_callback, account, mock_fork, env mock_callback("eth_gasPrice", "0x123") assert env.get_gas_price() == 291 - assert display_mock.call_count == 6 + assert display_mock.call_count == 7 (js,), _ = display_mock.call_args assert f'rpc("{token}", "eth_gasPrice", [])' in js.data @@ -205,8 +206,50 @@ def test_browser_rpc_server_error( assert str(exc_info.value) == "-32603: server error" +def test_browser_rpc_internal_error(mock_callback, env): + # this error was found while testing zksync in Google Colab with the browser RPC + # we are agnostic to the exact data, just testing that we find the key properly + error = { + "code": -32603, + "message": "Internal JSON-RPC error.", + "data": { + "code": 3, + "message": "insufficient funds for gas + value. balance: 0, fee: 116, value: 0", + "data": "0x", + "cause": None, + }, + } + mock_callback("eth_gasPrice", error=error) + with pytest.raises(RPCError) as exc_info: + env.get_gas_price() + assert ( + str(exc_info.value) + == "3: insufficient funds for gas + value. balance: 0, fee: 116, value: 0" + ) + + +def test_browser_rpc_debug_error(mock_callback, env): + # this error was found while testing zksync in Google Colab with the browser RPC + # we are agnostic to the exact data, just testing that we find the key properly + message = 'The method "debug_traceCall" does not exist / is not available.' + error = { + "error": { + "code": -32601, + "message": message, + "data": { + "origin": "https://44wgpcbsrwx-496ff2e9c6d22116-0-colab.googleusercontent.com", + "cause": None, + }, + } + } + mock_callback("eth_gasPrice", error=error) + with pytest.raises(RPCError) as exc_info: + env.get_gas_price() + assert str(exc_info.value) == f"-32601: {message}" + + def test_browser_js_error(token, display_mock, mock_callback, account, mock_fork): mock_callback("loadSigner", error={"message": "custom message", "stack": ""}) with pytest.raises(RPCError) as exc_info: BrowserSigner() - assert str(exc_info.value) == "CALLBACK_ERROR: custom message" + assert str(exc_info.value) == "-1: custom message" diff --git a/tests/unitary/jupyter/test_handlers.py b/tests/unitary/jupyter/test_handlers.py index 8a5da923..501d7923 100644 --- a/tests/unitary/jupyter/test_handlers.py +++ b/tests/unitary/jupyter/test_handlers.py @@ -59,7 +59,7 @@ def test_value_error(callback_handler, token, shared_memory): callback_handler.post(token) assert callback_handler.get_status() == 413 callback_handler.finish.assert_called_once_with( - {"error": "Request body has 51201 bytes, but only 51200 are allowed"} + {"error": "Request body has 102401 bytes, but only 102400 are allowed"} ) diff --git a/tests/unitary/test_abi.py b/tests/unitary/test_abi.py index 1115efb2..831140a2 100644 --- a/tests/unitary/test_abi.py +++ b/tests/unitary/test_abi.py @@ -169,7 +169,8 @@ def test(n: uint256) -> uint256: def test_abi_not_deployed(): - f = ABIFunction({"name": "test", "inputs": [], "outputs": []}, contract_name="c") + fn_abi = {"name": "test", "inputs": [], "outputs": [], "type": "function"} + f = ABIFunction(fn_abi, contract_name="c") with pytest.raises(Exception) as exc_info: f() (error,) = exc_info.value.args diff --git a/tests/unitary/test_reverts.py b/tests/unitary/test_reverts.py index cadba031..58d01c04 100644 --- a/tests/unitary/test_reverts.py +++ b/tests/unitary/test_reverts.py @@ -207,6 +207,31 @@ def ext_call2(): p.ext_call2() +def test_stack_trace(contract): + c = boa.loads( + """ +interface HasFoo: + def foo(x: uint256): nonpayable + +@external +def revert(contract: HasFoo): + contract.foo(5) + """ + ) + + with pytest.raises(BoaError) as context: + c.revert(contract.address) + + trace = [ + (line.contract_repr, line.error_detail, line.pretty_vm_reason) + for line in context.value.stack_trace + ] + assert trace == [ + (repr(contract), "user revert with reason", "x is not 4"), + (repr(c), "external call failed", "x is not 4"), + ] + + def test_trace_constructor_revert(): code = """ @external