Skip to content

Commit

Permalink
fix: trace API fixes (#2093)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed May 31, 2024
1 parent 17bce2f commit b9d254f
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 517 deletions.
6 changes: 3 additions & 3 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,12 +573,12 @@ def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI":
Enhance the data in the call tree using information about the ecosystem.
Args:
call (:class:`~ape.api.trace.TraceAPI`): The trace to enrich.
kwargs: Additional kwargs to control enrichment, defined at the
trace (:class:`~ape.api.trace.TraceAPI`): The trace to enrich.
**kwargs: Additional kwargs to control enrichment, defined at the
plugin level.
Returns:
:class:`~ape.types.trace.CallTreeNode`
:class:`~ape.api.trace.TraceAPI`
"""
return trace

Expand Down
119 changes: 76 additions & 43 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
ProxyInfo,
ProxyType,
)
from ape_ethereum.trace import Trace, TransactionTrace
from ape_ethereum.trace import _REVERT_PREFIX, Trace, TransactionTrace
from ape_ethereum.transactions import (
AccessListTransaction,
BaseTransaction,
Expand Down Expand Up @@ -1003,6 +1003,7 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]:
)

def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI:
kwargs["trace"] = trace
if not isinstance(trace, Trace):
return trace

Expand All @@ -1016,10 +1017,10 @@ def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI:
# Get the un-enriched calltree.
data = trace.get_calltree().model_dump(mode="json", by_alias=True)

# Return value was discovered already.
if isinstance(trace, TransactionTrace):
return_value = trace.__dict__.get("return_value") if data.get("depth", 0) == 0 else None
if return_value is not None:
# Return value was discovered already.
kwargs["return_value"] = return_value

enriched_calltree = self._enrich_calltree(data, **kwargs)
Expand Down Expand Up @@ -1052,7 +1053,9 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
# Figure out the contract.
address = call.pop("address", "")
try:
call["contract_id"] = address = str(self.decode_address(address))
call["contract_id"] = address = kwargs["contract_address"] = str(
self.decode_address(address)
)
except Exception:
# Tx was made with a weird address.
call["contract_id"] = address
Expand Down Expand Up @@ -1122,6 +1125,10 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
# For constructors, don't include outputs, as it is likely a large amount of bytes.
call["returndata"] = None

elif "revert_message" not in call:
# Method not found but perhaps we still know the error.
call = self._enrich_revert_message(call)

return call

def _enrich_contract_id(self, address: AddressType, **kwargs) -> str:
Expand Down Expand Up @@ -1199,12 +1206,36 @@ def _enrich_returndata(self, call: dict, method_abi: MethodABI, **kwargs) -> dic
call["returndata"] = ""
return call

elif "revert_message" in call:
# Already enriched, in a sense..
return call

default_return_value = "<?>"
returndata = call.get("returndata")
returndata = call.get("returndata", "")
is_hexstr = isinstance(returndata, str) and is_0x_prefixed(returndata)
return_value_bytes = None

if (
returndata and isinstance(returndata, str) and is_0x_prefixed(returndata)
) or isinstance(returndata, (int, bytes)):
# Check if return is only a revert string.
call = self._enrich_revert_message(call)
if "revert_message" in call:
return call

elif is_hexstr:
return_value_bytes = HexBytes(returndata)

# Check if custom-error.
if "trace" in kwargs and "contract_address" in kwargs:
address = kwargs["contract_address"]
try:
instance = self.decode_custom_error(return_value_bytes, address, **kwargs)
except NotImplementedError:
pass
else:
if instance is not None:
call["revert_message"] = repr(instance)
return call

elif is_hexstr or isinstance(returndata, (int, bytes)):
return_value_bytes = HexBytes(returndata)
else:
return_value_bytes = None
Expand Down Expand Up @@ -1246,6 +1277,16 @@ def _enrich_returndata(self, call: dict, method_abi: MethodABI, **kwargs) -> dic
call["returndata"] = output_val
return call

def _enrich_revert_message(self, call: dict) -> dict:
returndata = call.get("returndata", "")
is_hexstr = isinstance(returndata, str) and is_0x_prefixed(returndata)
if is_hexstr and returndata.startswith(_REVERT_PREFIX):
# The returndata is the revert-str.
decoded_result = decode(("string",), HexBytes(returndata)[4:])
call["revert_message"] = decoded_result[0] if len(decoded_result) == 1 else ""

return call

def get_python_types(self, abi_type: ABIType) -> Union[type, Sequence]:
return self._python_type_for_abi_type(abi_type)

Expand All @@ -1264,51 +1305,43 @@ def decode_custom_error(
selector = data[:4]
input_data = data[4:]

abi = None
if selector not in contract.contract_type.errors:
# ABI not found. Try looking at the "last" contract.
if not (tx := kwargs.get("txn")) or not self.network_manager.active_provider:
return None

try:
tx_hash = tx.txn_hash
except SignatureError:
return None

if not (last_addr := self._get_last_address_from_trace(tx_hash)):
return None

if last_addr == address:
# Avoid checking same address twice.
return None
if selector in contract.contract_type.errors:
abi = contract.contract_type.errors[selector]
error_cls = contract.get_error_by_signature(abi.signature)
inputs = self.decode_calldata(abi, input_data)
kwargs["contract_address"] = address
error_kwargs = {
k: v
for k, v in kwargs.items()
if k in ("trace", "txn", "contract_address", "source_traceback")
}
return error_cls(abi, inputs, **error_kwargs)

try:
if not (cerr := self.decode_custom_error(data, last_addr)):
return cerr
except NotImplementedError:
return None
# ABI not found. Try looking at the "last" contract.
if not (tx := kwargs.get("txn")) or not self.network_manager.active_provider:
return None

# error never found.
try:
tx_hash = tx.txn_hash
except SignatureError:
return None

abi = contract.contract_type.errors[selector]
error_cls = contract.get_error_by_signature(abi.signature)
inputs = self.decode_calldata(abi, input_data)
kwargs["contract_address"] = address
return error_cls(abi, inputs, **kwargs)
trace = kwargs.get("trace") or self.provider.get_transaction_trace(tx_hash)
if not (last_addr := next(trace.get_addresses_used(reverse=True), None)):
return None

def _get_last_address_from_trace(self, txn_hash: Union[str, HexBytes]) -> Optional[AddressType]:
try:
trace = list(self.chain_manager.provider.get_transaction_trace(txn_hash))
except Exception:
if last_addr == address:
# Avoid checking same address twice.
return None

for frame in trace[::-1]:
if not (addr := frame.contract_address):
continue
try:
if cerr := self.decode_custom_error(data, last_addr, **kwargs):
return cerr

return addr
except NotImplementedError:
return None

# error never found.
return None


Expand Down
32 changes: 16 additions & 16 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI:
if "call_trace_approach" not in kwargs:
kwargs["call_trace_approach"] = self._call_trace_approach

return self._get_transaction_trace(transaction_hash, **kwargs)
return TransactionTrace(transaction_hash=transaction_hash, **kwargs)

def send_call(
self,
Expand Down Expand Up @@ -994,6 +994,21 @@ def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any:

return result

def stream_request(self, method: str, params: Iterable, iter_path: str = "result.item"):
if not (uri := self.http_uri):
raise ProviderError("This provider has no HTTP URI and is unable to stream requests.")

payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params}
results = ijson.sendable_list()
coroutine = ijson.items_coro(results, iter_path)
resp = requests.post(uri, json=payload, stream=True)
resp.raise_for_status()

for chunk in resp.iter_content(chunk_size=2**17):
coroutine.send(chunk)
yield from results
del results[:]

def create_access_list(
self, transaction: TransactionAPI, block_id: Optional[BlockID] = None
) -> list[AccessList]:
Expand Down Expand Up @@ -1125,9 +1140,6 @@ def _handle_execution_reverted(

return enriched

def _get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI:
return TransactionTrace(transaction_hash=transaction_hash, **kwargs)


class EthereumNodeProvider(Web3Provider, ABC):
# optimal values for geth
Expand Down Expand Up @@ -1289,18 +1301,6 @@ def _get_contract_creation_receipt(self, address: AddressType) -> Optional[Recei

return None

def stream_request(self, method: str, params: Iterable, iter_path: str = "result.item"):
payload = {"jsonrpc": "2.0", "id": 1, "method": method, "params": params}
results = ijson.sendable_list()
coroutine = ijson.items_coro(results, iter_path)
resp = requests.post(self.uri, json=payload, stream=True)
resp.raise_for_status()

for chunk in resp.iter_content(chunk_size=2**17):
coroutine.send(chunk)
yield from results
del results[:]

def connect(self):
self._set_web3()
if not self.is_connected:
Expand Down
Loading

0 comments on commit b9d254f

Please sign in to comment.