diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 5efac38b..c5b13a24 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -9,7 +9,6 @@ import time import traceback import uuid - from collections import Counter from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import asdict, dataclass @@ -18,13 +17,9 @@ from .bytevec import ByteVec from .calldata import Calldata -from .config import ( - arg_parser, - default_config, - resolve_config_files, - toml_parser, - Config as HalmosConfig, -) +from .config import Config as HalmosConfig +from .config import arg_parser, default_config, resolve_config_files, toml_parser +from .mapper import DeployAddressMapper, Mapper from .sevm import * from .utils import ( NamedTimer, @@ -281,18 +276,16 @@ def rendered_trace(context: CallContext) -> str: return output.getvalue() -def rendered_calldata(calldata: ByteVec) -> str: - return hexify(calldata.unwrap()) if calldata else "0x" +def rendered_calldata(calldata: ByteVec, contract_name: str = None) -> str: + return hexify(calldata.unwrap(), contract_name) if calldata else "0x" def render_trace(context: CallContext, file=sys.stdout) -> None: - # TODO: label for known addresses - # TODO: decode calldata - # TODO: decode logs - message = context.message addr = unbox_int(message.target) addr_str = str(addr) if is_bv(addr) else hex(addr) + # check if we have a contract name for this address in our deployment mapper + addr_str = DeployAddressMapper().get_deployed_contract(addr_str) value = unbox_int(message.value) value_str = f" (value: {value})" if is_bv(value) or value > 0 else "" @@ -303,13 +296,29 @@ def render_trace(context: CallContext, file=sys.stdout) -> None: if message.is_create(): # TODO: select verbosity level to render full initcode # initcode_str = rendered_initcode(context) + + try: + if context.output.error is None: + target = hex(int(str(message.target))) + bytecode = context.output.data.unwrap().hex() + contract_name = ( + Mapper() + .get_contract_mapping_info_by_bytecode(bytecode) + .contract_name + ) + + DeployAddressMapper().add_deployed_contract(target, contract_name) + addr_str = contract_name + except: + pass + initcode_str = f"<{byte_length(message.data)} bytes of initcode>" print( f"{indent}{call_scheme_str}{addr_str}::{initcode_str}{value_str}", file=file ) else: - calldata = rendered_calldata(message.data) + calldata = rendered_calldata(message.data, addr_str) call_str = f"{addr_str}::{calldata}" static_str = yellow(" [static]") if message.is_static else "" print(f"{indent}{call_scheme_str}{call_str}{static_str}{value_str}", file=file) @@ -1348,6 +1357,29 @@ def parse_build_out(args: HalmosConfig) -> Dict: sol_dirname, ) contract_map[contract_name] = (json_out, contract_type, natspec) + + try: + bytecode = contract_map[contract_name][0]["bytecode"]["object"] + contract_mapping_info = Mapper().get_contract_mapping_info_by_name( + contract_name + ) + + if contract_mapping_info is None: + Mapper().add_contract_mapping_info( + contract_name=contract_name, + bytecode=bytecode, + nodes=[], + ) + else: + contract_mapping_info.bytecode = bytecode + + contract_mapping_info = Mapper().get_contract_mapping_info_by_name( + contract_name + ) + Mapper().parse_ast(contract_map[contract_name][0]["ast"]) + + except Exception: + pass except Exception as err: warn_code( PARSING_ERROR, @@ -1571,6 +1603,9 @@ def on_signal(signum, frame): contract_path = f"{contract_json['ast']['absolutePath']}:{contract_name}" print(f"\nRunning {num_found} tests for {contract_path}") + # Set 0xaaaa0001 in DeployAddressMapper + DeployAddressMapper().add_deployed_contract("0xaaaa0001", contract_name) + # support for `/// @custom:halmos` annotations contract_args = with_natspec(args, contract_name, natspec) run_args = RunArgs( diff --git a/src/halmos/mapper.py b/src/halmos/mapper.py new file mode 100644 index 00000000..fa59fdc9 --- /dev/null +++ b/src/halmos/mapper.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + + +@dataclass +class AstNode: + node_type: str + id: int + name: str + address: str # TODO: rename it to `selector` or `signature` to better reflect the meaning + visibility: str + + +@dataclass +class ContractMappingInfo: + contract_name: str + bytecode: str + nodes: List[AstNode] + + +class SingletonMeta(type): + _instances: Dict[Type, Any] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + + return cls._instances[cls] + + +class Mapper(metaclass=SingletonMeta): + """ + Mapping from a contract name to its runtime bytecode and the signatures of functions/events/errors declared in the contract + """ + + _PARSING_IGNORED_NODE_TYPES = [ + "StructDefinition", + "EnumDefinition", + "PragmaDirective", + "ImportDirective", + "Block", + ] + + def __init__(self): + self._contracts: Dict[str, ContractMappingInfo] = {} + + def add_contract_mapping_info( + self, contract_name: str, bytecode: str, nodes: List[AstNode] + ): + if contract_name in self._contracts: + raise ValueError(f"Contract {contract_name} already exists") + + self._contracts[contract_name] = ContractMappingInfo( + contract_name, bytecode, nodes + ) + + def get_contract_mapping_info_by_name( + self, contract_name: str + ) -> Optional[ContractMappingInfo]: + return self._contracts.get(contract_name, None) + + def get_contract_mapping_info_by_bytecode( + self, bytecode: str + ) -> Optional[ContractMappingInfo]: + # TODO: Handle cases for contracts with immutable variables + # Current implementation might not work correctly if the following code is added the test solidity file + # + # address immutable public owner; + # constructor() { + # owner = msg.sender; + # } + + for contract_mapping_info in self._contracts.values(): + # TODO: use regex instaed of `endswith` to better handle immutables or constructors with arguments + if contract_mapping_info.bytecode.endswith(bytecode): + return contract_mapping_info + + return None + + def append_node(self, contract_name: str, node: AstNode): + contract_mapping_info = self.get_contract_mapping_info_by_name(contract_name) + + if contract_mapping_info is None: + raise ValueError(f"Contract {contract_name} not found") + + contract_mapping_info.nodes.append(node) + + def parse_ast(self, node: Dict, contract_name: str = ""): + node_type = node["nodeType"] + + if node_type in self._PARSING_IGNORED_NODE_TYPES: + return + + current_contract = self._get_current_contract(node, contract_name) + + if node_type == "ContractDefinition": + if current_contract not in self._contracts: + self.add_contract_mapping_info( + contract_name=current_contract, bytecode="", nodes=[] + ) + + if self.get_contract_mapping_info_by_name(current_contract).nodes: + return + elif node_type != "SourceUnit": + id, name, address, visibility = self._get_node_info(node, node_type) + + self.append_node( + current_contract, + AstNode(node_type, id, name, address, visibility), + ) + + for child_node in node.get("nodes", []): + self.parse_ast(child_node, current_contract) + + if "body" in node: + self.parse_ast(node["body"], current_contract) + + def _get_node_info(self, node: Dict, node_type: str) -> Dict: + return ( + node.get("id", ""), + node.get("name", ""), + "0x" + self._get_node_address(node, node_type), + node.get("visibility", ""), + ) + + def _get_node_address(self, node: Dict, node_type: str) -> str: + address_fields = { + "VariableDeclaration": "functionSelector", + "FunctionDefinition": "functionSelector", + "EventDefinition": "eventSelector", + "ErrorDefinition": "errorSelector", + } + + return node.get(address_fields.get(node_type, ""), "") + + def _get_current_contract(self, node: Dict, contract_name: str) -> str: + return ( + node.get("name", "") + if node["nodeType"] == "ContractDefinition" + else contract_name + ) + + def find_nodes_by_address(self, address: str, contract_name: str = None): + # if the given signature is declared in the given contract, return its name. + if contract_name: + contract_mapping_info = self.get_contract_mapping_info_by_name( + contract_name + ) + + if contract_mapping_info: + for node in contract_mapping_info.nodes: + if node.address == address: + return node.name + + # otherwise, search for the signature in other contracts, and return all the contracts that declare it. + # note: ambiguity may occur if multiple compilation units exist. + result = "" + for key, contract_info in self._contracts.items(): + matching_nodes = [ + node for node in contract_info.nodes if node.address == address + ] + + for node in matching_nodes: + result += f"{key}.{node.name} " + + return result.strip() if result != "" and address != "0x" else address + + +# TODO: create a new instance or reset for each test +class DeployAddressMapper(metaclass=SingletonMeta): + """ + Mapping from deployed addresses to contract names + """ + + def __init__(self): + self._deployed_contracts: Dict[str, str] = {} + + # Set up some default mappings + self.add_deployed_contract( + "0x7109709ecfa91a80626ff3989d68f67f5b1dd12d", "HEVM_ADDRESS" + ) + self.add_deployed_contract( + "0xf3993a62377bcd56ae39d773740a5390411e8bc9", "SVM_ADDRESS" + ) + + def add_deployed_contract( + self, + address: str, + contract_name: str, + ): + self._deployed_contracts[address] = contract_name + + def get_deployed_contract(self, address: str) -> Optional[str]: + return self._deployed_contracts.get(address, address) diff --git a/src/halmos/utils.py b/src/halmos/utils.py index b3eb4854..0331a632 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: AGPL-3.0 import re - +from functools import partial from timeit import default_timer as timer -from typing import Dict, Tuple, Any, Optional, Union as UnionType +from typing import Any, Dict, Optional, Tuple +from typing import Union as UnionType from z3 import * -from .exceptions import NotConcreteError, HalmosException +from halmos.mapper import Mapper + +from .exceptions import HalmosException, NotConcreteError # order of the secp256k1 curve secp256k1n = ( @@ -318,23 +321,25 @@ def decode_hex(hexstring: str) -> Optional[bytes]: return None -def hexify(x): +def hexify(x, contract_name: str = None): if isinstance(x, str): return re.sub(r"\b(\d+)\b", lambda match: hex(int(match.group(1))), x) elif isinstance(x, int): return f"0x{x:02x}" elif isinstance(x, bytes): - return "0x" + x.hex() + return Mapper().find_nodes_by_address("0x" + x.hex(), contract_name) elif hasattr(x, "unwrap"): - return hexify(x.unwrap()) + return hexify(x.unwrap(), contract_name) elif is_bv_value(x): # maintain the byte size of x num_bytes = byte_length(x, strict=False) - return f"0x{x.as_long():0{num_bytes * 2}x}" + return Mapper().find_nodes_by_address( + f"0x{x.as_long():0{num_bytes * 2}x}", contract_name + ) elif is_app(x): - return f"{str(x.decl())}({', '.join(map(hexify, x.children()))})" + return f"{str(x.decl())}({', '.join(map(partial(hexify, contract_name=contract_name), x.children()))})" else: - return hexify(str(x)) + return hexify(str(x), contract_name) def render_uint(x: BitVecRef) -> str: diff --git a/tests/test_mapper.py b/tests/test_mapper.py new file mode 100644 index 00000000..ce7f1fba --- /dev/null +++ b/tests/test_mapper.py @@ -0,0 +1,176 @@ +from typing import List + +import pytest + +from halmos.mapper import AstNode, ContractMappingInfo, Mapper, SingletonMeta + + +@pytest.fixture +def ast_nodes() -> List[AstNode]: + return [ + AstNode( + node_type="type1", id=1, name="Node1", address="0x123", visibility="public" + ), + AstNode( + node_type="type2", id=2, name="Node2", address="0x456", visibility="private" + ), + ] + + +@pytest.fixture +def mapper() -> Mapper: + return Mapper() + + +@pytest.fixture(autouse=True) +def reset_singleton(): + SingletonMeta._instances = {} + + +def test_singleton(): + mapper1 = Mapper() + mapper2 = Mapper() + assert mapper1 is mapper2 + + +def test_add_contract_mapping_info(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is not None + assert contract_info.contract_name == "ContractA" + assert contract_info.bytecode == "bytecodeA" + assert len(contract_info.nodes) == 2 + + +def test_add_contract_mapping_info_already_existence(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + + with pytest.raises(ValueError, match=r"Contract ContractA already exists"): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + + +def test_get_contract_mapping_info_by_name(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is not None + assert contract_info.contract_name == "ContractA" + + +def test_get_contract_mapping_info_by_name_nonexistent(mapper): + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is None + + +def test_get_contract_mapping_info_by_bytecode(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + contract_info = mapper.get_contract_mapping_info_by_bytecode("bytecodeA") + assert contract_info is not None + assert contract_info.bytecode == "bytecodeA" + + +def test_get_contract_mapping_info_by_bytecode_nonexistent(mapper): + contract_info = mapper.get_contract_mapping_info_by_bytecode("bytecodeA") + assert contract_info is None + + +def test_append_node(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + new_node = AstNode( + node_type="type3", id=3, name="Node3", address="0x789", visibility="public" + ) + mapper.append_node("ContractA", new_node) + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is not None + assert len(contract_info.nodes) == 3 + assert contract_info.nodes[-1].id == 3 + + +def test_append_node_to_nonexistent_contract(mapper): + new_node = AstNode( + node_type="type3", id=3, name="Node3", address="0x789", visibility="public" + ) + with pytest.raises(ValueError, match=r"Contract NonexistentContract not found"): + mapper.append_node("NonexistentContract", new_node) + + +def test_parse_simple_ast(mapper): + example_ast = { + "nodeType": "ContractDefinition", + "id": 1, + "name": "ExampleContract", + "nodes": [ + { + "nodeType": "FunctionDefinition", + "id": 2, + "name": "exampleFunction", + "functionSelector": "abcdef", + "visibility": "public", + "nodes": [], + } + ], + } + mapper.parse_ast(example_ast) + contract_info = mapper.get_contract_mapping_info_by_name("ExampleContract") + + assert contract_info is not None + assert contract_info.contract_name == "ExampleContract" + assert len(contract_info.nodes) == 1 + assert contract_info.nodes[0].name == "exampleFunction" + + +def test_parse_complex_ast(mapper): + complex_ast = { + "nodeType": "ContractDefinition", + "id": 1, + "name": "ComplexContract", + "nodes": [ + { + "nodeType": "VariableDeclaration", + "id": 2, + "name": "var1", + "functionSelector": "", + "visibility": "private", + }, + { + "nodeType": "FunctionDefinition", + "id": 3, + "name": "func1", + "functionSelector": "222222", + "visibility": "public", + "nodes": [ + { + "nodeType": "Block", + "id": 4, + "name": "innerBlock", + "functionSelector": "", + "visibility": "", + } + ], + }, + { + "nodeType": "EventDefinition", + "id": 5, + "name": "event1", + "eventSelector": "444444", + "visibility": "public", + }, + { + "nodeType": "ErrorDefinition", + "id": 6, + "name": "error1", + "errorSelector": "555555", + "visibility": "public", + }, + ], + } + mapper.parse_ast(complex_ast) + contract_info = mapper.get_contract_mapping_info_by_name("ComplexContract") + assert contract_info is not None + assert contract_info.contract_name == "ComplexContract" + assert len(contract_info.nodes) == 4 + + node_names = [node.name for node in contract_info.nodes] + assert "var1" in node_names + assert "func1" in node_names + assert "event1" in node_names + assert "error1" in node_names