Skip to content

Commit

Permalink
Merge pull request #195 from DanielSchiavini/zksync
Browse files Browse the repository at this point in the history
feat: refactor for zksync abstraction
  • Loading branch information
charles-cooper authored Jun 12, 2024
2 parents 2e0d5c1 + 5f48854 commit 2a0065d
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 86 deletions.
80 changes: 53 additions & 27 deletions boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import defaultdict
from copy import deepcopy
from functools import cached_property
from os.path import basename
from typing import Any, Optional, Union
from warnings import warn

Expand Down Expand Up @@ -33,7 +32,9 @@ def __init__(self, abi: dict, contract_name: str):
self.contract: Optional["ABIContract"] = None

@property
def name(self) -> str:
def name(self) -> str | None:
if self.is_constructor:
return None
return self._abi["name"]

@cached_property
Expand All @@ -54,16 +55,28 @@ def return_type(self) -> list:

@property
def full_signature(self) -> str:
assert self.name is not None, "Constructor does not have a name."
return f"{self.name}{self.signature}"

@property
def pretty_signature(self) -> str:
return f"{self.name}{self.signature} -> {self.return_type}"
return f"{self.pretty_name}{self.signature} -> {self.return_type}"

@cached_property
def pretty_name(self):
if self.is_constructor:
return "constructor"
return self.name

@cached_property
def method_id(self) -> bytes:
assert self.name, "Constructor does not have a method id."
return method_id(self.name + self.signature)

@cached_property
def is_constructor(self):
return self._abi["type"] == "constructor"

def __repr__(self) -> str:
return f"ABI {self._contract_name}.{self.pretty_signature}"

Expand All @@ -87,7 +100,10 @@ def is_encodable(self, *args, **kwargs) -> bool:
def prepare_calldata(self, *args, **kwargs) -> bytes:
"""Prepare the call data for the function call."""
abi_args = self._merge_kwargs(*args, **kwargs)
return self.method_id + abi_encode(self.signature, abi_args)
encoded_args = abi_encode(self.signature, abi_args)
if self.is_constructor:
return encoded_args
return self.method_id + encoded_args

def _merge_kwargs(self, *args, **kwargs) -> list:
"""Merge positional and keyword arguments into a single list."""
Expand Down Expand Up @@ -153,7 +169,7 @@ def __init__(self, functions: list[ABIFunction]):
self.functions = functions

@cached_property
def name(self) -> str:
def name(self) -> str | None:
return self.functions[0].name

def prepare_calldata(self, *args, disambiguate_signature=None, **kwargs) -> bytes:
Expand Down Expand Up @@ -193,6 +209,7 @@ def _pick_overload(
]
assert len(matches) <= 1, "ABI signature must be unique"

assert self.name, "Constructor does not have a name."
match matches:
case [function]:
return function
Expand All @@ -215,7 +232,7 @@ class ABIContract(_BaseEVMContract):
def __init__(
self,
name: str,
abi: dict,
abi: list[dict],
functions: list[ABIFunction],
address: Address,
filename: Optional[str] = None,
Expand All @@ -237,10 +254,12 @@ def __init__(
for f in self._functions:
overloads[f.name].append(f)

for name, group in overloads.items():
setattr(self, name, ABIOverload.create(group, self))
for fn_name, group in overloads.items():
if fn_name is not None: # constructors have no name
setattr(self, fn_name, ABIOverload.create(group, self))

self._address = Address(address)
self._computation: Optional[ComputationAPI] = None

@property
def abi(self):
Expand All @@ -252,14 +271,19 @@ def method_id_map(self):
Returns a mapping from method id to function object.
This is used to create the stack trace when an error occurs.
"""
return {function.method_id: function for function in self._functions}
return {
function.method_id: function
for function in self._functions
if not function.is_constructor
}

def marshal_to_python(self, computation, abi_type: list[str]) -> tuple[Any, ...]:
"""
Convert the output of a contract call to a Python object.
:param computation: the computation object returned by `execute_code`
:param abi_type: the ABI type of the return value.
"""
self._computation = computation
# when there's no contract in the address, the computation output is empty
if computation.is_error:
return self.handle_error(computation)
Expand All @@ -274,13 +298,17 @@ def stack_trace(self, computation: ComputationAPI) -> StackTrace:
"""
Create a stack trace for a failed contract call.
"""
reason = ""
if computation.is_error:
reason = " ".join(str(arg) for arg in computation.error.args if arg != b"")

calldata_method_id = bytes(computation.msg.data[:4])
if calldata_method_id in self.method_id_map:
function = self.method_id_map[calldata_method_id]
msg = f" ({self}.{function.pretty_signature})"
msg = f" {reason}({self}.{function.pretty_signature})"
else:
# Method might not be specified in the ABI
msg = f" (unknown method id {self}.0x{calldata_method_id.hex()})"
msg = f" {reason}(unknown method id {self}.0x{calldata_method_id.hex()})"

return_trace = StackTrace([msg])
return _handle_child_trace(computation, self.env, return_trace)
Expand All @@ -290,7 +318,7 @@ def deployer(self) -> "ABIContractFactory":
"""
Returns a factory that can be used to retrieve another deployed contract.
"""
return ABIContractFactory(self._name, self._abi, self._functions)
return ABIContractFactory(self._name, self._abi, filename=self.filename)

def __repr__(self):
file_str = f" (file {self.filename})" if self.filename else ""
Expand All @@ -305,36 +333,34 @@ class ABIContractFactory:
do any contract deployment.
"""

def __init__(
self,
name: str,
abi: dict,
functions: list[ABIFunction],
filename: Optional[str] = None,
):
def __init__(self, name: str, abi: list[dict], filename: Optional[str] = None):
self._name = name
self._abi = abi
self._functions = functions
self._filename = filename
self.filename = filename

@cached_property
def abi(self):
return deepcopy(self._abi)

@classmethod
def from_abi_dict(cls, abi, name="<anonymous contract>"):
functions = [
ABIFunction(item, name) for item in abi if item.get("type") == "function"
@cached_property
def functions(self):
return [
ABIFunction(item, self._name)
for item in self.abi
if item.get("type") == "function"
]
return cls(basename(name), abi, functions, filename=name)

@classmethod
def from_abi_dict(cls, abi, name="<anonymous contract>", filename=None):
return cls(name, abi, filename)

def at(self, address: Address | str) -> ABIContract:
"""
Create an ABI contract object for a deployed contract at `address`.
"""
address = Address(address)
contract = ABIContract(
self._name, self._abi, self._functions, address, self._filename
self._name, self._abi, self.functions, address, self.filename
)
contract.env.register_contract(address, contract)
return contract
Expand Down
7 changes: 3 additions & 4 deletions boa/contracts/base_evm_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ def last_frame(self):


def _trace_for_unknown_contract(computation, env):
ret = StackTrace(
[f"<Unknown location in unknown contract {computation.msg.code_address.hex()}>"]
)
return _handle_child_trace(computation, env, ret)
err = f" <Unknown contract 0x{computation.msg.code_address.hex()}>"
trace = StackTrace([err])
return _handle_child_trace(computation, env, trace)


def _handle_child_trace(computation, env, return_trace):
Expand Down
12 changes: 12 additions & 0 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@


class VyperDeployer:
create_compiler_data = CompilerData # this may be a different class in plugins

def __init__(self, compiler_data, filename=None):
self.compiler_data = compiler_data

Expand Down Expand Up @@ -303,6 +305,11 @@ def _check(cond, msg=""):
assert len(args) == 1, "multiple args!"
assert len(kwargs) == 0, "can't mix args and kwargs!"
err = args[0]
if isinstance(frame, str):
# frame for unknown contracts is a string
_check(err in frame, f"{frame} does not match {args}")
return

# try to match anything
_check(
err == frame.pretty_vm_reason
Expand All @@ -315,6 +322,10 @@ def _check(cond, msg=""):
# try to match a specific kwarg
assert len(kwargs) == 1 and len(args) == 0

if isinstance(frame, str):
# frame for unknown contracts is a string
raise ValueError(f"expected {kwargs} but got {frame}")

# don't accept magic
if frame.dev_reason:
assert frame.dev_reason.reason_type not in ("vm_error", "compiler")
Expand Down Expand Up @@ -547,6 +558,7 @@ def _set_bytecode(self, bytecode: bytes) -> None:
to_check = bytecode
if self.data_section_size != 0:
to_check = bytecode[: -self.data_section_size]
assert isinstance(self.compiler_data, CompilerData)
if to_check != self.compiler_data.bytecode_runtime:
warnings.warn(
f"casted bytecode does not match compiled bytecode at {self}",
Expand Down
53 changes: 31 additions & 22 deletions boa/integrations/jupyter/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,16 @@ def send_transaction(self, tx_data: dict) -> dict:
)
return convert_frontend_dict(sign_data)

def sign_typed_data(
self, domain: dict[str, Any], types: dict[str, list], value: dict[str, Any]
) -> str:
def sign_typed_data(self, full_message: dict[str, Any]) -> str:
"""
Sign typed data value with types data structure for domain using the EIP-712 specification.
:param domain: The domain data structure.
:param types: The types data structure.
:param value: The value to sign.
:param full_message: The full message to sign.
:return: The signature.
"""
return _javascript_call(
"signTypedData",
domain,
types,
value,
"rpc",
"eth_signTypedData_v4",
[self.address, full_message],
timeout_message=TRANSACTION_TIMEOUT_MESSAGE,
)

Expand Down Expand Up @@ -141,18 +136,10 @@ def __init__(self, address=None, **kwargs):
self.signer = BrowserSigner(address)
self.set_eoa(self.signer)

def get_chain_id(self) -> int:
chain_id = _javascript_call(
"rpc", "eth_chainId", timeout_message=RPC_TIMEOUT_MESSAGE
)
return int.from_bytes(bytes.fromhex(chain_id[2:]), "big")

def set_chain_id(self, chain_id: int | str):
_javascript_call(
"rpc",
self._rpc.fetch(
"wallet_switchEthereumChain",
[{"chainId": chain_id if isinstance(chain_id, str) else hex(chain_id)}],
timeout_message=RPC_TIMEOUT_MESSAGE,
)
self._reset_fork()

Expand All @@ -169,7 +156,7 @@ def _javascript_call(js_func: str, *args, timeout_message: str) -> Any:
:return: The result of the Javascript snippet sent to the API.
"""
token = _generate_token()
args_str = ", ".join(json.dumps(p) for p in chain([token], args))
args_str = ", ".join(json.dumps(p, cls=_BytesEncoder) for p in chain([token], args))
js_code = f"window._titanoboa.{js_func}({args_str});"
if BrowserRPC._debug_mode:
logging.warning(f"Calling {js_func} with {args_str}")
Expand Down Expand Up @@ -224,9 +211,31 @@ def _parse_js_result(result: dict) -> Any:
if "data" in result:
return result["data"]

def _find_key(input_dict, target_key, typ) -> Any:
for key, value in input_dict.items():
if isinstance(value, dict):
found = _find_key(value, target_key, typ)
if found is not None:
return found
if key == target_key and isinstance(value, typ) and value != "error":
return value
return None

# raise the error in the Jupyter cell so that the user can see it
error = result["error"]
error = error.get("info", error).get("error", error)
error = error.get("data", error)
raise RPCError(
message=error.get("message", error), code=error.get("code", "CALLBACK_ERROR")
message=_find_key(error, "message", str) or _find_key(error, "error", str),
code=_find_key(error, "code", int) or -1,
)


class _BytesEncoder(json.JSONEncoder):
"""
A JSONEncoder that converts bytes to hex strings to be passed to JavaScript.
"""

def default(self, o):
if isinstance(o, bytes):
return "0x" + o.hex()
return super().default(o)
2 changes: 1 addition & 1 deletion boa/integrations/jupyter/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

NUL = b"\0"
CALLBACK_TOKEN_TIMEOUT = timedelta(minutes=3)
SHARED_MEMORY_LENGTH = 50 * 1024 + len(NUL) # Size of the shared memory object
SHARED_MEMORY_LENGTH = 100 * 1024 + len(NUL) # Size of the shared memory object
CALLBACK_TOKEN_CHARS = 30 # OSx limits this to 31 characters
PLUGIN_NAME = "titanoboa_jupyterlab"
TOKEN_REGEX = rf"[0-9a-fA-F]{{{CALLBACK_TOKEN_CHARS}}}"
Expand Down
11 changes: 1 addition & 10 deletions boa/integrations/jupyter/jupyter.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,14 @@
return response.text();
}

let from;
const loadSigner = async (address) => {
const accounts = await rpc('eth_requestAccounts');
from = accounts.includes(address) ? address : accounts[0];
return from;
return accounts.includes(address) ? address : accounts[0];
};

/** Sign a transaction via ethers */
const sendTransaction = async transaction => ({"hash": await rpc('eth_sendTransaction', [transaction])});

/** Sign a typed data via ethers */
const signTypedData = (domain, types, value) => rpc(
'eth_signTypedData_v4',
[from, JSON.stringify({domain, types, value})]
);

/** Wait until the transaction is mined */
const waitForTransactionReceipt = async (tx_hash, timeout, poll_latency) => {
while (true) {
Expand Down Expand Up @@ -120,7 +112,6 @@
window._titanoboa = {
loadSigner: handleCallback(loadSigner),
sendTransaction: handleCallback(sendTransaction),
signTypedData: handleCallback(signTypedData),
waitForTransactionReceipt: handleCallback(waitForTransactionReceipt),
rpc: handleCallback(rpc),
multiRpc: handleCallback(multiRpc),
Expand Down
Loading

0 comments on commit 2a0065d

Please sign in to comment.