Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prefetch state checks #233

Merged
merged 13 commits into from
May 31, 2024
52 changes: 29 additions & 23 deletions boa/vm/fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ def _mk_key(self, method: str, params: Any) -> Any:
return json.dumps({"method": method, "params": params}).encode("utf-8")

def fetch(self, method, params):
# dispatch into fetch_multi for caching behavior.
(res,) = self.fetch_multi([(method, params)])
return res
# cannot dispatch into fetch_multi, doesn't work for debug_traceCall.
key = self._mk_key(method, params)
if key in self._db:
return json.loads(self._db[key])

result = self._rpc.fetch(method, params)
self._db[key] = json.dumps(result).encode("utf-8")
return result

def fetch_uncached(self, method, params):
return self._rpc.fetch_uncached(method, params)
Expand Down Expand Up @@ -214,20 +219,15 @@ def try_prefetch_state(self, msg: Message):
"data": msg.data,
}
)
# TODO: skip debug_traceCall if we have seen these specific
# arguments with this specific block before
try:
tracer = {"tracer": "prestateTracer"}
res = self._rpc.fetch_uncached(
"debug_traceCall", [args, self._block_id, tracer]
)
trace_args = [args, self._block_id, {"tracer": "prestateTracer"}]
trace = self._rpc.fetch("debug_traceCall", trace_args)
except (RPCError, HTTPError):
return

snapshot = self.record()

# everything is returned in hex
for address, v in res.items():
for address, account_trace in trace.items():
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
try:
address = to_canonical_address(address)
except ValueError:
Expand All @@ -236,14 +236,18 @@ def try_prefetch_state(self, msg: Message):
return

# set account if we don't already have it
if self._get_account_helper(address) is None:
balance = to_int(v.get("balance", "0x"))
code = to_bytes(v.get("code", "0x"))
nonce = v.get("nonce", 0) # already an int
account_helper = self._get_account_helper(address)
if account_helper is None:
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
balance = to_int(account_trace.get("balance", "0x"))
code = to_bytes(account_trace.get("code", "0x"))
nonce = account_trace.get("nonce", 0) # already an int
self._set_account(address, Account(nonce=nonce, balance=balance))
self.set_code(address, code)
self._dirty_accounts.add(address)
else:
self._account_cache[address] = account_helper

storage = v.get("storage", dict())
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
storage = account_trace.get("storage", dict())

account_store = super()._get_address_store(address)
for hexslot, hexvalue in storage.items():
Expand All @@ -256,16 +260,20 @@ def try_prefetch_state(self, msg: Message):
key = int_to_big_endian(slot)
if not self._helper_have_storage(address, slot):
account_store._journal_storage[key] = rlp.encode(value) # type: ignore
self.lock_changes()
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved

self.commit(snapshot)

def get_code(self, address):
try:
return super().get_code(address)
except MissingBytecode: # will get thrown if code_hash != hash(empty)
ret = self._rpc.fetch(
"eth_getCode", [to_checksum_address(address), self._block_id]
code = to_bytes(
self._rpc.fetch(
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
"eth_getCode", [to_checksum_address(address), self._block_id]
)
)
return to_bytes(ret)
self.set_code(address, code)
return code

def discard(self, checkpoint):
super().discard(checkpoint)
Expand All @@ -276,9 +284,7 @@ def commit(self, checkpoint):
self._dontfetch.commit(checkpoint)

def record(self):
checkpoint = super().record()
self._dontfetch.record(checkpoint)
return checkpoint
return self._dontfetch.record(super().record())

# helper to determine if something is in the storage db
# or we need to get from RPC
Expand Down
7 changes: 3 additions & 4 deletions boa/vm/py_evm.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,9 @@ def fork_rpc(self, rpc: RPC, block_identifier: str, **kwargs):

@property
def is_forked(self):
return self.vm.__class__._state_class.account_db_class == AccountDBFork

def _set_account_db_class(self, account_db_class: type):
self.vm.__class__._state_class.account_db_class = account_db_class
return issubclass(
self.vm.__class__._state_class.account_db_class, AccountDBFork
)

def get_gas_meter_class(self):
return self.vm.state.computation_class._gas_meter_class
Expand Down
10 changes: 7 additions & 3 deletions tests/integration/fork/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
from boa.environment import Env


@pytest.fixture(scope="module")
def rpc_url():
return os.environ.get("MAINNET_ENDPOINT") or "http://localhost:8545"


# run all tests with this forked environment
# called as fixture for its side effects
@pytest.fixture(scope="module", autouse=True)
def forked_env():
def forked_env(rpc_url):
with boa.swap_env(Env()):
fork_uri = os.environ.get("MAINNET_ENDPOINT") or "http://localhost:8545"
block_id = 18801970 # some block we know the state of
boa.env.fork(fork_uri, block_identifier=block_id)
boa.env.fork(rpc_url, block_identifier=block_id)
yield
41 changes: 41 additions & 0 deletions tests/integration/fork/test_from_etherscan.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
from unittest.mock import patch

import pytest
from eth.vm.message import Message
from vyper.utils import method_id

import boa
from boa import Env

crvusd = "0xf939E0A03FB07F59A73314E73794Be0E57ac1b4E"
voting_agent = "0xE478de485ad2fe566d49342Cbd03E49ed7DB3356"
Expand Down Expand Up @@ -33,3 +37,40 @@ def test_proxy_contract(proxy_contract):
assert proxy_contract.minTime() == 43200
assert proxy_contract.voteTime() == 604800
assert proxy_contract.minBalance() == 2500000000000000000000


@pytest.mark.parametrize("fresh_env", [True, False])
def test_prefetch_state(proxy_contract, rpc_url, fresh_env):
env = boa.env
if fresh_env:
env = Env()
env.fork(rpc_url)

msg = Message(
to=proxy_contract.address.canonical_address,
sender=env.eoa.canonical_address,
gas=0,
value=0,
code=proxy_contract._bytecode,
data=method_id("minTime()"),
)
db = env.evm.vm.state._account_db
db.try_prefetch_state(msg)
account = db._account_cache[proxy_contract.address.canonical_address]
assert db._journaldb[account.code_hash] == proxy_contract._bytecode


@pytest.mark.parametrize("prefetch", [True, False])
def test_prefetch_state_called_on_message(proxy_contract, prefetch):
boa.env.evm._fork_try_prefetch_state = prefetch
with patch("boa.vm.fork.AccountDBFork.try_prefetch_state") as mock:
DanielSchiavini marked this conversation as resolved.
Show resolved Hide resolved
proxy_contract.minTime()
assert mock.called == prefetch


@pytest.mark.parametrize("prefetch", [True, False])
def test_prefetch_state_called_on_deploy(proxy_contract, prefetch):
boa.env.evm._fork_try_prefetch_state = prefetch
with patch("boa.vm.fork.AccountDBFork.try_prefetch_state") as mock:
boa.loads("")
assert mock.called == prefetch
Loading