diff --git a/boa/contracts/abi/abi_contract.py b/boa/contracts/abi/abi_contract.py index 47f8b382..2579f68f 100644 --- a/boa/contracts/abi/abi_contract.py +++ b/boa/contracts/abi/abi_contract.py @@ -124,6 +124,10 @@ def __call__(self, *args, value=0, gas=None, sender=None, **kwargs): if not self.contract or not self.contract.env: raise Exception(f"Cannot call {self} without deploying contract.") + override_bytecode = None + if hasattr(self, "_override_bytecode"): + override_bytecode = self._override_bytecode + computation = self.contract.env.execute_code( to_address=self.contract.address, sender=sender, @@ -132,6 +136,7 @@ def __call__(self, *args, value=0, gas=None, sender=None, **kwargs): gas=gas, is_modifying=self.is_mutable, contract=self.contract, + override_bytecode=override_bytecode, ) match self.contract.marshal_to_python(computation, self.return_type): diff --git a/boa/contracts/vvm/vvm_contract.py b/boa/contracts/vvm/vvm_contract.py index 6eedc2b8..b6a2ba8a 100644 --- a/boa/contracts/vvm/vvm_contract.py +++ b/boa/contracts/vvm/vvm_contract.py @@ -1,51 +1,82 @@ import re from functools import cached_property +from pathlib import Path +from typing import Any, Optional -from boa.contracts.abi.abi_contract import ABIContractFactory, ABIFunction +import vvm +from vyper.utils import method_id + +from boa.contracts.abi.abi_contract import ABIContract, ABIContractFactory, ABIFunction from boa.environment import Env +from boa.rpc import to_bytes +from boa.util.abi import Address +from boa.util.disk_cache import get_disk_cache from boa.util.eip5202 import generate_blueprint_bytecode -# TODO: maybe this doesn't detect release candidates -VERSION_RE = re.compile(r"\s*#\s*(pragma\s+version|@version)\s+(\d+\.\d+\.\d+)") +def _compile_source(*args, **kwargs) -> Any: + """ + Compile Vyper source code via the VVM. + When a disk cache is available, the result of the compilation is cached. + """ + disk_cache = get_disk_cache() + + def _compile(): + return vvm.compile_source(*args, **kwargs) + + if disk_cache is None: + return _compile() -# TODO: maybe move this up to vvm? -def _detect_version(source_code: str): - res = VERSION_RE.findall(source_code) - if len(res) < 1: - return None - # TODO: handle len(res) > 1 - return res[0][1] + cache_key = f"{args}{kwargs}" + return disk_cache.caching_lookup(cache_key, _compile) -class VVMDeployer: +class VVMDeployer(ABIContractFactory): """ A deployer that uses the Vyper Version Manager (VVM). This allows deployment of contracts written in older versions of Vyper that can interact with new versions using the ABI definition. """ - def __init__(self, abi, bytecode, filename): + def __init__( + self, + name: str, + compiler_output: dict, + source_code: str, + vyper_version: str, + filename: Optional[str] = None, + ): """ Initialize a VVMDeployer instance. - :param abi: The contract's ABI. - :param bytecode: The contract's bytecode. + :param name: The name of the contract. + :param compiler_output: The compiler output of the contract. + :param source_code: The source code of the contract. + :param vyper_version: The Vyper version used to compile the contract. :param filename: The filename of the contract. """ - self.abi = abi - self.bytecode = bytecode - self.filename = filename + super().__init__(name, compiler_output["abi"], filename) + self.compiler_output = compiler_output + self.source_code = source_code + self.vyper_version = vyper_version + + @cached_property + def bytecode(self): + return to_bytes(self.compiler_output["bytecode"]) @classmethod - def from_compiler_output(cls, compiler_output, filename): - abi = compiler_output["abi"] - bytecode_nibbles = compiler_output["bytecode"] - bytecode = bytes.fromhex(bytecode_nibbles.removeprefix("0x")) - return cls(abi, bytecode, filename) + def from_source_code( + cls, + source_code: str, + vyper_version: str, + filename: Optional[str] = None, + name: Optional[str] = None, + ): + if name is None: + name = Path(filename).stem if filename is not None else "" + compiled_src = _compile_source(source_code, vyper_version=vyper_version) + compiler_output = compiled_src[""] - @cached_property - def factory(self): - return ABIContractFactory.from_abi_dict(self.abi) + return cls(name, compiler_output, source_code, vyper_version, filename) @cached_property def constructor(self): @@ -97,5 +128,249 @@ def deploy_as_blueprint(self, env=None, blueprint_preamble=None, **kwargs): def __call__(self, *args, **kwargs): return self.deploy(*args, **kwargs) - def at(self, address): - return self.factory.at(address) + def at(self, address: Address | str) -> "VVMContract": + """ + Create an ABI contract object for a deployed contract at `address`. + """ + address = Address(address) + contract = VVMContract( + compiler_output=self.compiler_output, + source_code=self.source_code, + vyper_version=self.vyper_version, + name=self._name, + abi=self._abi, + functions=self.functions, + address=address, + filename=self.filename, + ) + contract.env.register_contract(address, contract) + return contract + + +class VVMContract(ABIContract): + """ + A deployed contract compiled with vvm, which is called via ABI. + """ + + def __init__(self, compiler_output, source_code, vyper_version, *args, **kwargs): + super().__init__(*args, **kwargs) + self.compiler_output = compiler_output + self.source_code = source_code + self.vyper_version = vyper_version + + @cached_property + def bytecode(self): + return to_bytes(self.compiler_output["bytecode"]) + + @cached_property + def bytecode_runtime(self): + return to_bytes(self.compiler_output["bytecode_runtime"]) + + def inject_function(self, fn_source_code, force=False): + """ + Inject a function into this VVM Contract without affecting the + contract's source code. useful for testing private functionality. + :param fn_source_code: The source code of the function to inject. + :param force: If True, the function will be injected even if it already exists. + :returns: The result of the statement evaluation. + """ + fn = VVMInjectedFunction(fn_source_code, self) + if hasattr(self, fn.name) and not force: + raise ValueError(f"Function {fn.name} already exists on contract.") + setattr(self, fn.name, fn) + fn.contract = self + + @cached_property + def _storage(self): + """ + Allows access to the storage variables of the contract. + Note that this is quite slow, as it requires the complete contract to be + recompiled. + """ + + def storage(): + return None + + for name, spec in self.compiler_output["layout"]["storage_layout"].items(): + setattr(storage, name, VVMStorageVariable(name, spec, self)) + return storage + + @cached_property + def internal(self): + """ + Allows access to internal functions of the contract. + Note that this is quite slow, as it requires the complete contract to be + recompiled. + """ + + # an object with working setattr + def _obj(): + return None + + result = _compile_source( + self.source_code, vyper_version=self.vyper_version, output_format="metadata" + )["function_info"] + for fn_name, meta in result.items(): + if meta["visibility"] == "internal": + function = VVMInternalFunction(meta, self) + setattr(_obj, function.name, function) + return _obj + + +class _VVMInternal(ABIFunction): + """ + An ABI function that temporarily changes the bytecode at the contract's address. + Subclasses of this class are used to inject code into the contract via the + `source_code` property using the vvm, temporarily changing the bytecode + at the contract's address. + """ + + @cached_property + def _override_bytecode(self) -> bytes: + return to_bytes(self._compiler_output["bytecode_runtime"]) + + @cached_property + def _compiler_output(self): + assert isinstance(self.contract, VVMContract) # help mypy + source = "\n".join((self.contract.source_code, self.source_code)) + compiled = _compile_source(source, vyper_version=self.contract.vyper_version) + return compiled[""] + + @property + def source_code(self) -> str: + """ + Returns the source code an internal function. + Must be implemented in subclasses. + """ + raise NotImplementedError + + +class VVMInternalFunction(_VVMInternal): + """ + An internal function that is made available via the `internal` namespace. + It will temporarily change the bytecode at the contract's address. + """ + + def __init__(self, meta: dict, contract: VVMContract): + abi = { + "anonymous": False, + "inputs": [ + {"name": arg_name, "type": arg_type} + for arg_name, arg_type in meta["positional_args"].items() + ], + "outputs": ( + [{"name": meta["name"], "type": meta["return_type"]}] + if meta["return_type"] != "None" + else [] + ), + "stateMutability": meta["mutability"], + "name": meta["name"], + "type": "function", + } + super().__init__(abi, contract.contract_name) + self.contract = contract + + @cached_property + def method_id(self) -> bytes: + return method_id(f"__boa_internal_{self.name}__" + self.signature) + + @cached_property + def source_code(self): + fn_args = ", ".join([arg["name"] for arg in self._abi["inputs"]]) + + return_sig = "" + fn_call = "" + if self.return_type: + return_sig = f" -> {self.return_type}" + fn_call = "return " + + fn_call += f"self.{self.name}({fn_args})" + fn_sig = ", ".join( + f"{arg['name']}: {arg['type']}" for arg in self._abi["inputs"] + ) + return f""" +@external +@payable +def __boa_internal_{self.name}__({fn_sig}){return_sig}: + {fn_call} +""" + + +class VVMStorageVariable(_VVMInternal): + """ + A storage variable that is made available via the `storage` namespace. + It will temporarily change the bytecode at the contract's address. + """ + + def __init__(self, name, spec, contract): + inputs, output_type = _get_storage_variable_types(spec) + abi = { + "anonymous": False, + "inputs": inputs, + "outputs": [{"name": name, "type": output_type}], + "name": name, + "type": "function", + } + super().__init__(abi, contract.contract_name) + self.contract = contract + + def get(self, *args): + # get the value of the storage variable. note that this is + # different from the behavior of VyperContract storage variables! + return self.__call__(*args) + + @cached_property + def method_id(self) -> bytes: + return method_id(f"__boa_private_{self.name}__" + self.signature) + + @cached_property + def source_code(self): + getter_call = "".join(f"[{i['name']}]" for i in self._abi["inputs"]) + args_signature = ", ".join( + f"{i['name']}: {i['type']}" for i in self._abi["inputs"] + ) + return f""" +@external +@payable +def __boa_private_{self.name}__({args_signature}) -> {self.return_type[0]}: + return self.{self.name}{getter_call} +""" + + +class VVMInjectedFunction(_VVMInternal): + """ + A Vyper function that is injected into a VVM contract. + It will temporarily change the bytecode at the contract's address. + """ + + def __init__(self, source_code: str, contract: VVMContract): + self.contract = contract + self._source_code = source_code + abi = [i for i in self._compiler_output["abi"] if i not in contract.abi] + if len(abi) != 1: + err = "Expected exactly one new ABI entry after injecting function. " + err += f"Found {abi}." + raise ValueError(err) + + super().__init__(abi[0], contract.contract_name) + + @cached_property + def source_code(self): + return self.code + + +def _get_storage_variable_types(spec: dict) -> tuple[list[dict], str]: + """ + Get the types of a storage variable + :param spec: The storage variable specification. + :return: The types of the storage variable: + 1. A list of dictionaries containing the input types. + 2. The output type name. + """ + hashmap_regex = re.compile(r"^HashMap\[([^[]+), (.+)]$") + output_type = spec["type"] + inputs: list[dict] = [] + while output_type.startswith("HashMap"): + key_type, output_type = hashmap_regex.match(output_type).groups() # type: ignore + inputs.append({"name": f"key{len(inputs)}", "type": key_type}) + return inputs, output_type diff --git a/boa/interpret.py b/boa/interpret.py index 17136d3c..a768c67e 100644 --- a/boa/interpret.py +++ b/boa/interpret.py @@ -4,10 +4,11 @@ from importlib.machinery import SourceFileLoader from importlib.util import spec_from_loader from pathlib import Path -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Optional, Union import vvm import vyper +from packaging.version import Version from vyper.ast.parse import parse_to_ast from vyper.cli.vyper_compile import get_search_paths from vyper.compiler.input_bundle import ( @@ -23,7 +24,7 @@ from vyper.utils import sha256sum from boa.contracts.abi.abi_contract import ABIContractFactory -from boa.contracts.vvm.vvm_contract import VVMDeployer, _detect_version +from boa.contracts.vvm.vvm_contract import VVMDeployer from boa.contracts.vyper.vyper_contract import ( VyperBlueprint, VyperContract, @@ -33,15 +34,14 @@ from boa.explorer import Etherscan, get_etherscan from boa.rpc import json from boa.util.abi import Address -from boa.util.disk_cache import DiskCache +# export set_cache_dir, NOTE: consider moving to boa/__init__.py +from boa.util.disk_cache import get_disk_cache, set_cache_dir # noqa: F401 if TYPE_CHECKING: from vyper.semantics.analysis.base import ImportInfo _Contract = Union[VyperContract, VyperBlueprint] - -_disk_cache = None _search_path = None @@ -50,22 +50,6 @@ def set_search_path(path: list[str]): _search_path = path -def set_cache_dir(cache_dir="~/.cache/titanoboa"): - global _disk_cache - if cache_dir is None: - _disk_cache = None - return - compiler_version = f"{vyper.__version__}.{vyper.__commit__}" - _disk_cache = DiskCache(cache_dir, compiler_version) - - -def disable_cache(): - set_cache_dir(None) - - -set_cache_dir() # enable caching, by default! - - class BoaImporter(MetaPathFinder): def find_spec(self, fullname, path, target=None): path = Path(fullname.replace(".", "/")).with_suffix(".vy") @@ -131,7 +115,7 @@ def get_module_fingerprint( def compiler_data( source_code: str, contract_name: str, filename: str | Path, deployer=None, **kwargs ) -> CompilerData: - global _disk_cache, _search_path + global _search_path path = Path(contract_name) resolved_path = Path(filename).resolve(strict=False) @@ -145,7 +129,9 @@ def compiler_data( settings = Settings(**kwargs) ret = CompilerData(file_input, input_bundle, settings) - if _disk_cache is None: + + disk_cache = get_disk_cache() + if disk_cache is None: return ret with anchor_settings(ret.settings): @@ -165,7 +151,7 @@ def get_compiler_data(): assert isinstance(deployer, type) or deployer is None deployer_id = repr(deployer) # a unique str identifying the deployer class cache_key = str((contract_name, fingerprint, kwargs, deployer_id)) - return _disk_cache.caching_lookup(cache_key, get_compiler_data) + return disk_cache.caching_lookup(cache_key, get_compiler_data) def load(filename: str | Path, *args, **kwargs) -> _Contract: # type: ignore @@ -239,21 +225,18 @@ def loads_partial( dedent: bool = True, compiler_args: dict = None, ) -> VyperDeployer: - name = name or "VyperContract" - filename = filename or "" - + name = name or "VyperContract" # TODO handle this upstream in CompilerData if dedent: source_code = textwrap.dedent(source_code) - version = _detect_version(source_code) - if version is not None and version != vyper.__version__: - filename = str(filename) # help mypy - # TODO: pass name to loads_partial_vvm, not filename - return _loads_partial_vvm(source_code, version, filename) + version = vvm.detect_vyper_version_from_source(source_code) + if version is not None and version != Version(vyper.__version__): + return _loads_partial_vvm(source_code, str(version), filename, name) compiler_args = compiler_args or {} deployer_class = _get_default_deployer_class() + filename = filename or "" data = compiler_data(source_code, name, filename, deployer_class, **compiler_args) return deployer_class(data, filename=filename) @@ -265,25 +248,19 @@ def load_partial(filename: str, compiler_args=None): ) -def _loads_partial_vvm(source_code: str, version: str, filename: str): - global _disk_cache +def _loads_partial_vvm( + source_code: str, + version: str, + filename: Optional[str | Path] = None, + name: Optional[str] = None, +): + if filename is not None: + filename = str(filename) # install the requested version if not already installed vvm.install_vyper(version=version) - def _compile(): - compiled_src = vvm.compile_source(source_code, vyper_version=version) - compiler_output = compiled_src[""] - return VVMDeployer.from_compiler_output(compiler_output, filename=filename) - - # Ensure the cache is initialized - if _disk_cache is None: - return _compile() - - # Generate a unique cache key - cache_key = f"{source_code}:{version}" - # Check the cache and return the result if available - return _disk_cache.caching_lookup(cache_key, _compile) + return VVMDeployer.from_source_code(source_code, version, filename, name) def from_etherscan( diff --git a/boa/util/disk_cache.py b/boa/util/disk_cache.py index 6912ec51..19160f00 100644 --- a/boa/util/disk_cache.py +++ b/boa/util/disk_cache.py @@ -5,6 +5,9 @@ import threading import time from pathlib import Path +from typing import Optional + +import vyper _ONE_WEEK = 7 * 24 * 3600 @@ -77,3 +80,26 @@ def caching_lookup(self, string, func): # because worst case we will just rebuild the item tmp_p.rename(p) return res + + +_disk_cache = None + + +def get_disk_cache() -> Optional[DiskCache]: + return _disk_cache + + +def set_cache_dir(cache_dir: Optional[str] = "~/.cache/titanoboa"): + global _disk_cache + if cache_dir is None: + _disk_cache = None + return + compiler_version = f"{vyper.__version__}.{vyper.__commit__}" + _disk_cache = DiskCache(cache_dir, compiler_version) + + +def disable_cache(): + set_cache_dir(None) + + +set_cache_dir() # enable caching, by default! diff --git a/pyproject.toml b/pyproject.toml index 323c21b4..8dd3c559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "pytest-cov", # required to compile older versions of vyper - "vvm", + "vvm>=0.3.1", # eth-rlp requirement, not installed by default with 3.12 "typing-extensions", diff --git a/tests/unitary/contracts/vvm/mock_3_10.vy b/tests/unitary/contracts/vvm/mock_3_10.vy index bd23c5a7..28e5e909 100644 --- a/tests/unitary/contracts/vvm/mock_3_10.vy +++ b/tests/unitary/contracts/vvm/mock_3_10.vy @@ -3,7 +3,20 @@ foo: public(uint256) bar: public(uint256) +hash_map: HashMap[address, HashMap[uint8, uint256]] +is_empty: bool + @external def __init__(bar: uint256): self.foo = 42 self.bar = bar + self.is_empty = True + +@external +def set_map(x: uint256): + self._set_map(msg.sender, x) + +@internal +def _set_map(addr: address, x: uint256): + self.hash_map[addr][0] = x + self.is_empty = False diff --git a/tests/unitary/contracts/vvm/test_vvm.py b/tests/unitary/contracts/vvm/test_vvm.py index 3f10106d..ab29ea79 100644 --- a/tests/unitary/contracts/vvm/test_vvm.py +++ b/tests/unitary/contracts/vvm/test_vvm.py @@ -1,6 +1,10 @@ +import pytest + import boa mock_3_10_path = "tests/unitary/contracts/vvm/mock_3_10.vy" +with open(mock_3_10_path) as f: + mock_3_10_code = f.read() def test_load_partial_vvm(): @@ -12,10 +16,7 @@ def test_load_partial_vvm(): def test_loads_partial_vvm(): - with open(mock_3_10_path) as f: - code = f.read() - - contract_deployer = boa.loads_partial(code) + contract_deployer = boa.loads_partial(mock_3_10_code) contract = contract_deployer.deploy(43) assert contract.foo() == 42 @@ -30,15 +31,58 @@ def test_load_vvm(): def test_loads_vvm(): - with open(mock_3_10_path) as f: - code = f.read() - - contract = boa.loads(code, 43) + contract = boa.loads(mock_3_10_code, 43) assert contract.foo() == 42 assert contract.bar() == 43 +def test_vvm_storage(): + contract = boa.loads(mock_3_10_code, 43) + assert contract._storage.is_empty.get() + assert contract._storage.hash_map.get(boa.env.eoa, 0) == 0 + contract.set_map(69) + assert not contract._storage.is_empty.get() + assert contract._storage.hash_map.get(boa.env.eoa, 0) == 69 + + +def test_vvm_internal(): + contract = boa.loads(mock_3_10_code, 43) + assert not hasattr(contract.internal, "set_map") + address = boa.env.generate_address() + contract.internal._set_map(address, 69) + assert contract._storage.hash_map.get(address, 0) == 69 + + +def test_vvm_inject_fn(): + contract = boa.loads(mock_3_10_code, 43) + contract.inject_function( + """ +@external +def set_bar(bar: uint256): + self.bar = bar +""" + ) + assert contract.bar() == 43 + assert contract.set_bar(44) is None + assert contract.bar() == 44 + + +def test_vvm_inject_fn_exists(): + contract = boa.loads(mock_3_10_code, 43) + code = """ +@external +def bytecode(): + assert False, "Function injected" +""" + with pytest.raises(ValueError) as e: + contract.inject_function(code) + assert "Function bytecode already exists" in str(e.value) + contract.inject_function(code, force=True) + with boa.reverts("Function injected"): + contract.bytecode() + + def test_forward_args_on_deploy(): with open(mock_3_10_path) as f: code = f.read() diff --git a/tests/unitary/fixtures/module_contract.vy b/tests/unitary/fixtures/module_contract.vy index 886424db..9a243065 100644 --- a/tests/unitary/fixtures/module_contract.vy +++ b/tests/unitary/fixtures/module_contract.vy @@ -1,4 +1,4 @@ -# pragma version ^0.4.0 +# pragma version ~=0.4.0 import module_lib diff --git a/tests/unitary/utils/test_cache.py b/tests/unitary/utils/test_cache.py index 64659821..c21cfe84 100644 --- a/tests/unitary/utils/test_cache.py +++ b/tests/unitary/utils/test_cache.py @@ -4,12 +4,13 @@ from vyper.compiler import CompilerData from boa.contracts.vyper.vyper_contract import VyperDeployer -from boa.interpret import _disk_cache, _loads_partial_vvm, compiler_data, set_cache_dir +from boa.interpret import _loads_partial_vvm, compiler_data, get_disk_cache +from boa.util.disk_cache import set_cache_dir @pytest.fixture(autouse=True) def cache_dir(tmp_path): - tmp = _disk_cache.cache_dir + tmp = get_disk_cache().cache_dir try: set_cache_dir(tmp_path) yield @@ -21,7 +22,7 @@ def test_cache_contract_name(): code = """ x: constant(int128) = 1000 """ - assert _disk_cache is not None + assert get_disk_cache() is not None test1 = compiler_data(code, "test1", __file__, VyperDeployer) test2 = compiler_data(code, "test2", __file__, VyperDeployer) test3 = compiler_data(code, "test1", __file__, VyperDeployer) @@ -36,7 +37,7 @@ def test_cache_vvm(): """ version = "0.2.8" version2 = "0.3.1" - assert _disk_cache is not None + assert get_disk_cache() is not None # Mock vvm.compile_source with patch("vvm.compile_source") as mock_compile: