Skip to content

Commit

Permalink
Merge branch 'main' into perf/early-exit-fail
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Jul 17, 2024
2 parents 78dbeea + 7ba5a74 commit be42683
Show file tree
Hide file tree
Showing 4 changed files with 434 additions and 24 deletions.
65 changes: 50 additions & 15 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import time
import traceback
import uuid

from collections import Counter
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import asdict, dataclass
Expand All @@ -18,13 +17,9 @@

from .bytevec import ByteVec
from .calldata import Calldata
from .config import (
arg_parser,
default_config,
resolve_config_files,
toml_parser,
Config as HalmosConfig,
)
from .config import Config as HalmosConfig
from .config import arg_parser, default_config, resolve_config_files, toml_parser
from .mapper import DeployAddressMapper, Mapper
from .sevm import *
from .utils import (
NamedTimer,
Expand Down Expand Up @@ -281,18 +276,16 @@ def rendered_trace(context: CallContext) -> str:
return output.getvalue()


def rendered_calldata(calldata: ByteVec) -> str:
return hexify(calldata.unwrap()) if calldata else "0x"
def rendered_calldata(calldata: ByteVec, contract_name: str = None) -> str:
return hexify(calldata.unwrap(), contract_name) if calldata else "0x"


def render_trace(context: CallContext, file=sys.stdout) -> None:
# TODO: label for known addresses
# TODO: decode calldata
# TODO: decode logs

message = context.message
addr = unbox_int(message.target)
addr_str = str(addr) if is_bv(addr) else hex(addr)
# check if we have a contract name for this address in our deployment mapper
addr_str = DeployAddressMapper().get_deployed_contract(addr_str)

value = unbox_int(message.value)
value_str = f" (value: {value})" if is_bv(value) or value > 0 else ""
Expand All @@ -303,13 +296,29 @@ def render_trace(context: CallContext, file=sys.stdout) -> None:
if message.is_create():
# TODO: select verbosity level to render full initcode
# initcode_str = rendered_initcode(context)

try:
if context.output.error is None:
target = hex(int(str(message.target)))
bytecode = context.output.data.unwrap().hex()
contract_name = (
Mapper()
.get_contract_mapping_info_by_bytecode(bytecode)
.contract_name
)

DeployAddressMapper().add_deployed_contract(target, contract_name)
addr_str = contract_name
except:
pass

initcode_str = f"<{byte_length(message.data)} bytes of initcode>"
print(
f"{indent}{call_scheme_str}{addr_str}::{initcode_str}{value_str}", file=file
)

else:
calldata = rendered_calldata(message.data)
calldata = rendered_calldata(message.data, addr_str)
call_str = f"{addr_str}::{calldata}"
static_str = yellow(" [static]") if message.is_static else ""
print(f"{indent}{call_scheme_str}{call_str}{static_str}{value_str}", file=file)
Expand Down Expand Up @@ -1348,6 +1357,29 @@ def parse_build_out(args: HalmosConfig) -> Dict:
sol_dirname,
)
contract_map[contract_name] = (json_out, contract_type, natspec)

try:
bytecode = contract_map[contract_name][0]["bytecode"]["object"]
contract_mapping_info = Mapper().get_contract_mapping_info_by_name(
contract_name
)

if contract_mapping_info is None:
Mapper().add_contract_mapping_info(
contract_name=contract_name,
bytecode=bytecode,
nodes=[],
)
else:
contract_mapping_info.bytecode = bytecode

contract_mapping_info = Mapper().get_contract_mapping_info_by_name(
contract_name
)
Mapper().parse_ast(contract_map[contract_name][0]["ast"])

except Exception:
pass
except Exception as err:
warn_code(
PARSING_ERROR,
Expand Down Expand Up @@ -1571,6 +1603,9 @@ def on_signal(signum, frame):
contract_path = f"{contract_json['ast']['absolutePath']}:{contract_name}"
print(f"\nRunning {num_found} tests for {contract_path}")

# Set 0xaaaa0001 in DeployAddressMapper
DeployAddressMapper().add_deployed_contract("0xaaaa0001", contract_name)

# support for `/// @custom:halmos` annotations
contract_args = with_natspec(args, contract_name, natspec)
run_args = RunArgs(
Expand Down
194 changes: 194 additions & 0 deletions src/halmos/mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type


@dataclass
class AstNode:
node_type: str
id: int
name: str
address: str # TODO: rename it to `selector` or `signature` to better reflect the meaning
visibility: str


@dataclass
class ContractMappingInfo:
contract_name: str
bytecode: str
nodes: List[AstNode]


class SingletonMeta(type):
_instances: Dict[Type, Any] = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)

return cls._instances[cls]


class Mapper(metaclass=SingletonMeta):
"""
Mapping from a contract name to its runtime bytecode and the signatures of functions/events/errors declared in the contract
"""

_PARSING_IGNORED_NODE_TYPES = [
"StructDefinition",
"EnumDefinition",
"PragmaDirective",
"ImportDirective",
"Block",
]

def __init__(self):
self._contracts: Dict[str, ContractMappingInfo] = {}

def add_contract_mapping_info(
self, contract_name: str, bytecode: str, nodes: List[AstNode]
):
if contract_name in self._contracts:
raise ValueError(f"Contract {contract_name} already exists")

self._contracts[contract_name] = ContractMappingInfo(
contract_name, bytecode, nodes
)

def get_contract_mapping_info_by_name(
self, contract_name: str
) -> Optional[ContractMappingInfo]:
return self._contracts.get(contract_name, None)

def get_contract_mapping_info_by_bytecode(
self, bytecode: str
) -> Optional[ContractMappingInfo]:
# TODO: Handle cases for contracts with immutable variables
# Current implementation might not work correctly if the following code is added the test solidity file
#
# address immutable public owner;
# constructor() {
# owner = msg.sender;
# }

for contract_mapping_info in self._contracts.values():
# TODO: use regex instaed of `endswith` to better handle immutables or constructors with arguments
if contract_mapping_info.bytecode.endswith(bytecode):
return contract_mapping_info

return None

def append_node(self, contract_name: str, node: AstNode):
contract_mapping_info = self.get_contract_mapping_info_by_name(contract_name)

if contract_mapping_info is None:
raise ValueError(f"Contract {contract_name} not found")

contract_mapping_info.nodes.append(node)

def parse_ast(self, node: Dict, contract_name: str = ""):
node_type = node["nodeType"]

if node_type in self._PARSING_IGNORED_NODE_TYPES:
return

current_contract = self._get_current_contract(node, contract_name)

if node_type == "ContractDefinition":
if current_contract not in self._contracts:
self.add_contract_mapping_info(
contract_name=current_contract, bytecode="", nodes=[]
)

if self.get_contract_mapping_info_by_name(current_contract).nodes:
return
elif node_type != "SourceUnit":
id, name, address, visibility = self._get_node_info(node, node_type)

self.append_node(
current_contract,
AstNode(node_type, id, name, address, visibility),
)

for child_node in node.get("nodes", []):
self.parse_ast(child_node, current_contract)

if "body" in node:
self.parse_ast(node["body"], current_contract)

def _get_node_info(self, node: Dict, node_type: str) -> Dict:
return (
node.get("id", ""),
node.get("name", ""),
"0x" + self._get_node_address(node, node_type),
node.get("visibility", ""),
)

def _get_node_address(self, node: Dict, node_type: str) -> str:
address_fields = {
"VariableDeclaration": "functionSelector",
"FunctionDefinition": "functionSelector",
"EventDefinition": "eventSelector",
"ErrorDefinition": "errorSelector",
}

return node.get(address_fields.get(node_type, ""), "")

def _get_current_contract(self, node: Dict, contract_name: str) -> str:
return (
node.get("name", "")
if node["nodeType"] == "ContractDefinition"
else contract_name
)

def find_nodes_by_address(self, address: str, contract_name: str = None):
# if the given signature is declared in the given contract, return its name.
if contract_name:
contract_mapping_info = self.get_contract_mapping_info_by_name(
contract_name
)

if contract_mapping_info:
for node in contract_mapping_info.nodes:
if node.address == address:
return node.name

# otherwise, search for the signature in other contracts, and return all the contracts that declare it.
# note: ambiguity may occur if multiple compilation units exist.
result = ""
for key, contract_info in self._contracts.items():
matching_nodes = [
node for node in contract_info.nodes if node.address == address
]

for node in matching_nodes:
result += f"{key}.{node.name} "

return result.strip() if result != "" and address != "0x" else address


# TODO: create a new instance or reset for each test
class DeployAddressMapper(metaclass=SingletonMeta):
"""
Mapping from deployed addresses to contract names
"""

def __init__(self):
self._deployed_contracts: Dict[str, str] = {}

# Set up some default mappings
self.add_deployed_contract(
"0x7109709ecfa91a80626ff3989d68f67f5b1dd12d", "HEVM_ADDRESS"
)
self.add_deployed_contract(
"0xf3993a62377bcd56ae39d773740a5390411e8bc9", "SVM_ADDRESS"
)

def add_deployed_contract(
self,
address: str,
contract_name: str,
):
self._deployed_contracts[address] = contract_name

def get_deployed_contract(self, address: str) -> Optional[str]:
return self._deployed_contracts.get(address, address)
23 changes: 14 additions & 9 deletions src/halmos/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# SPDX-License-Identifier: AGPL-3.0

import re

from functools import partial
from timeit import default_timer as timer
from typing import Dict, Tuple, Any, Optional, Union as UnionType
from typing import Any, Dict, Optional, Tuple
from typing import Union as UnionType

from z3 import *

from .exceptions import NotConcreteError, HalmosException
from halmos.mapper import Mapper

from .exceptions import HalmosException, NotConcreteError

# order of the secp256k1 curve
secp256k1n = (
Expand Down Expand Up @@ -318,23 +321,25 @@ def decode_hex(hexstring: str) -> Optional[bytes]:
return None


def hexify(x):
def hexify(x, contract_name: str = None):
if isinstance(x, str):
return re.sub(r"\b(\d+)\b", lambda match: hex(int(match.group(1))), x)
elif isinstance(x, int):
return f"0x{x:02x}"
elif isinstance(x, bytes):
return "0x" + x.hex()
return Mapper().find_nodes_by_address("0x" + x.hex(), contract_name)
elif hasattr(x, "unwrap"):
return hexify(x.unwrap())
return hexify(x.unwrap(), contract_name)
elif is_bv_value(x):
# maintain the byte size of x
num_bytes = byte_length(x, strict=False)
return f"0x{x.as_long():0{num_bytes * 2}x}"
return Mapper().find_nodes_by_address(
f"0x{x.as_long():0{num_bytes * 2}x}", contract_name
)
elif is_app(x):
return f"{str(x.decl())}({', '.join(map(hexify, x.children()))})"
return f"{str(x.decl())}({', '.join(map(partial(hexify, contract_name=contract_name), x.children()))})"
else:
return hexify(str(x))
return hexify(str(x), contract_name)


def render_uint(x: BitVecRef) -> str:
Expand Down
Loading

0 comments on commit be42683

Please sign in to comment.