Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code fastpath scanning for valid jump destinations #348

Merged
merged 24 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
523a7b3
WIP
karmacoma-eth Aug 10, 2024
981fb45
delete unused definitions
karmacoma-eth Aug 14, 2024
fddfc1c
add optional dependencies for benchmarking
karmacoma-eth Aug 14, 2024
8c80246
refactor fast path and insn_len
karmacoma-eth Aug 14, 2024
ca12f87
Contract instances are immutable, avoid copies
karmacoma-eth Aug 15, 2024
a7ff795
cleanup
karmacoma-eth Aug 16, 2024
fca8066
fix tests (since we removed contract iteration)
karmacoma-eth Aug 16, 2024
f891e82
unused import
karmacoma-eth Aug 16, 2024
85d963f
DUPn: avoid resimplifying stack elements
karmacoma-eth Aug 16, 2024
8316f62
stack push: fast path for concrete values, no need to check size of s…
karmacoma-eth Aug 16, 2024
dc85678
use constant ZERO instead of con(0)
karmacoma-eth Aug 16, 2024
b29b766
avoid str(cond) pattern
karmacoma-eth Aug 16, 2024
c878a86
simplify PUSHn execution
karmacoma-eth Aug 19, 2024
4079178
remove spurious simplify call from unbox_int
karmacoma-eth Aug 19, 2024
616a8c4
perf: new method to compute jumpi_id
karmacoma-eth Aug 19, 2024
a292df6
fix tests
karmacoma-eth Aug 19, 2024
de8d7c2
better type annotations
karmacoma-eth Aug 19, 2024
20ea62a
fix type syntax for 3.11
karmacoma-eth Aug 20, 2024
b279253
fix HalmosLogs.bounded_loops
karmacoma-eth Aug 20, 2024
6ebdaaa
fix more uses of HalmosLogs.bounded_loops
karmacoma-eth Aug 20, 2024
61343e2
consistency: con(1) -> ONE
karmacoma-eth Aug 20, 2024
6dd088d
clean up types a bit
karmacoma-eth Aug 21, 2024
af3df48
fix revert condition check
karmacoma-eth Aug 21, 2024
afd5c11
add symbolic revert test
karmacoma-eth Aug 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions requirements-benchmark.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# optional dependencies, for benchmarking
pytest-benchmark
pytest-benchmark[histogram]
162 changes: 98 additions & 64 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: AGPL-3.0

import math
import re

from copy import deepcopy
Expand All @@ -22,7 +21,7 @@
)
from z3 import *

from .bytevec import Chunk, ByteVec
from .bytevec import ByteVec, Chunk, ConcreteChunk, UnwrappedBytes
from .cheatcodes import halmos_cheat_code, hevm_cheat_code, Prank
from .config import Config as HalmosConfig
from .console import console
Expand All @@ -45,10 +44,7 @@


# symbolic states
# calldataload(index)
f_calldataload = Function("f_calldataload", BitVecSort256, BitVecSort256)
# calldatasize()
f_calldatasize = Function("f_calldatasize", BitVecSort256)

# extcodesize(target address)
f_extcodesize = Function("f_extcodesize", BitVecSort160, BitVecSort256)
# extcodehash(target address)
Expand Down Expand Up @@ -82,17 +78,18 @@
new_address_offset: int = 1


def insn_len(opcode: int) -> int:
return 1 + (opcode - EVM.PUSH0) * (EVM.PUSH1 <= opcode <= EVM.PUSH32)


class Instruction:
opcode: int
pc: int = -1
next_pc: int = -1
operand: Optional[ByteVec] = None

def __init__(self, opcode, pc=-1, next_pc=-1, operand=None) -> None:
def __init__(self, opcode, pc=-1, operand=None) -> None:
self.opcode = opcode

self.pc = pc
self.next_pc = next_pc
self.operand = operand

def __str__(self) -> str:
Expand All @@ -103,7 +100,7 @@ def __repr__(self) -> str:
return f"Instruction({mnemonic(self.opcode)}, pc={self.pc}, operand={repr(self.operand)})"

def __len__(self) -> int:
return self.next_pc - self.pc
return insn_len(self.opcode)


def id_str(x: Any) -> str:
Expand Down Expand Up @@ -330,7 +327,7 @@ def pop(self) -> Word:
return self.stack.pop()

def dup(self, n: int) -> None:
self.push(self.stack[-n])
self.stack.append(self.stack[-n])
daejunpark marked this conversation as resolved.
Show resolved Hide resolved

def swap(self, n: int) -> None:
self.stack[-(n + 1)], self.stack[-1] = self.stack[-1], self.stack[-(n + 1)]
Expand Down Expand Up @@ -373,22 +370,61 @@ def __init__(self, **kwargs) -> None:
class Contract:
"""Abstraction over contract bytecode. Can include concrete and symbolic elements."""

_code: ByteVec
_fastcode: Optional[bytes]
_insn: Dict[int, Instruction]
_next_pc: Dict[int, int]
daejunpark marked this conversation as resolved.
Show resolved Hide resolved
_jumpdests: Optional[set]

def __init__(self, code: Optional[ByteVec] = None) -> None:
# if
if not isinstance(code, ByteVec):
code = ByteVec(code)

self._code = code
self._fastcode = None

# if the bytecode starts with a concrete prefix, we store it separately for fast access
# (this is a common case, especially for test contracts that deploy other contracts)
if code.chunks:
first_chunk = code.chunks[0]
if isinstance(first_chunk, ConcreteChunk):
self._fastcode = first_chunk.unwrap()

# maps pc to decoded instruction (including operand and next_pc)
self._insn = dict()
self._next_pc = dict()
self._jumpdests = None

def __init_jumpdests(self):
assert not hasattr(self, "_jumpdests")
self._jumpdests = set((pc for (pc, op) in iter(self) if op == EVM.JUMPDEST))
def __deepcopy__(self, memo):
# the class is essentially immutable (the only mutable fields are caches)
# so we can return the object itself instead of creating a new copy
return self

def __get_jumpdests(self):
# quick scan, does not eagerly decode instructions
jumpdests = set()
pc = 0

# optimistically process fast path first
for bytecode in (self._fastcode, self._code):
if not bytecode:
continue

N = len(bytecode)
while pc < N:
try:
opcode = int_of(bytecode[pc])

if opcode == EVM.JUMPDEST:
jumpdests.add(pc)

next_pc = pc + insn_len(opcode)
self._next_pc[pc] = next_pc
pc = next_pc
except NotConcreteError:
break

def __iter__(self):
return CodeIterator(self)
daejunpark marked this conversation as resolved.
Show resolved Hide resolved
return jumpdests

def from_hexcode(hexcode: str):
"""Create a contract from a hexcode string, e.g. "aabbccdd" """
Expand All @@ -409,37 +445,57 @@ def from_hexcode(hexcode: str):
except ValueError as e:
raise ValueError(f"{e} (hexcode={hexcode})")

def _decode_instruction(self, pc: int) -> Instruction:
opcode = int_of(self._code[pc], f"symbolic opcode at pc={pc}")

if EVM.PUSH1 <= opcode <= EVM.PUSH32:
operand_offset = pc + 1
operand_size = opcode - EVM.PUSH0
next_pc = operand_offset + operand_size
def _decode_instruction(self, pc: int) -> Tuple[Instruction, int]:
opcode = int_of(self[pc], f"symbolic opcode at pc={pc}")
daejunpark marked this conversation as resolved.
Show resolved Hide resolved
length = insn_len(opcode)
next_pc = pc + length

if length > 1:
# TODO: consider slicing lazily
operand = self.slice(operand_offset, next_pc).unwrap()
return Instruction(opcode, pc=pc, operand=operand, next_pc=next_pc)
operand = self.unwrapped_slice(pc + 1, next_pc)
return (Instruction(opcode, pc=pc, operand=operand), next_pc)

return Instruction(opcode, pc=pc, next_pc=pc + 1)
return (Instruction(opcode, pc=pc), next_pc)

def decode_instruction(self, pc: int) -> Instruction:
insn = self._insn.get(pc, None)
if insn is None:
insn = self._decode_instruction(pc)
"""decode instruction at pc and cache the result"""

if (insn := self._insn.get(pc)) is None:
insn, next_pc = self._decode_instruction(pc)
self._insn[pc] = insn
self._next_pc[pc] = next_pc

return insn

def next_pc(self, pc):
return self.decode_instruction(pc).next_pc
if (result := self._next_pc.get(pc)) is not None:
return result

self.decode_instruction(pc)
return self._next_pc[pc]

def slice(self, start, stop) -> ByteVec:
# fast path for offsets in the concrete prefix
if self._fastcode and stop < len(self._fastcode):
return ByteVec(self._fastcode[start:stop])

return self._code.slice(start, stop)

def unwrapped_slice(self, start, stop) -> UnwrappedBytes:
# fast path for offsets in the concrete prefix
if self._fastcode and stop < len(self._fastcode):
return self._fastcode[start:stop]

return self._code.slice(start, stop).unwrap()

def __getitem__(self, key: int) -> Byte:
"""Returns the byte at the given offset."""
offset = int_of(key, "symbolic index into contract bytecode {offset!r}")

# fast path for offsets in the concrete prefix
if self._fastcode and offset < len(self._fastcode):
return self._fastcode[offset]

return self._code.get_byte(offset)

def __len__(self) -> int:
Expand All @@ -448,34 +504,12 @@ def __len__(self) -> int:

def valid_jump_destinations(self) -> set:
"""Returns the set of valid jump destinations."""
if not hasattr(self, "_jumpdests"):
self.__init_jumpdests()
if self._jumpdests is None:
self._jumpdests = self.__get_jumpdests()

return self._jumpdests


class CodeIterator:
def __init__(self, contract: Contract):
self.contract = contract
self.pc = 0

def __iter__(self):
return self

def __next__(self) -> Tuple[int, int]:
"""Returns a tuple of (pc, opcode)"""
if self.pc >= len(self.contract):
raise StopIteration

try:
pc = self.pc
insn = self.contract.decode_instruction(pc)
self.pc = insn.next_pc
return (pc, insn.opcode)
except NotConcreteError:
raise StopIteration


@dataclass(frozen=True)
class SMTQuery:
smtlib: str
Expand Down Expand Up @@ -788,7 +822,7 @@ def dump(self, print_mem=False) -> str:
)
)

def next_pc(self) -> None:
def advance_pc(self) -> None:
self.pc = self.pgm.next_pc(self.pc)

def check(self, cond: Any) -> Any:
Expand Down Expand Up @@ -1678,7 +1712,7 @@ def callback(new_ex: Exec, stack, step_id):
new_ex.balance = orig_balance

# add to worklist even if it reverted during the external call
new_ex.next_pc()
new_ex.advance_pc()
stack.push(new_ex, step_id)

sub_ex = Exec(
Expand Down Expand Up @@ -1836,7 +1870,7 @@ def call_unknown() -> None:
# TODO: check if still needed
ex.calls.append((exit_code_var, exit_code, ex.context.output.data))

ex.next_pc()
ex.advance_pc()
stack.push(ex, step_id)

# precompiles or cheatcodes
Expand Down Expand Up @@ -1932,7 +1966,7 @@ def create(
if new_addr in ex.code:
# address conflicts don't revert, they push 0 on the stack and continue
ex.st.push(0)
ex.next_pc()
ex.advance_pc()

# add a virtual subcontext to the trace for debugging purposes
subcall = CallContext(message=message, depth=ex.context.depth + 1)
Expand All @@ -1958,7 +1992,7 @@ def create(
# transfer value
self.transfer_value(ex, pranked_caller, new_addr, value)

def callback(new_ex, stack, step_id):
def callback(new_ex: Exec, stack, step_id):
subcall = new_ex.context

# continue execution in the context of the parent
Expand Down Expand Up @@ -1996,7 +2030,7 @@ def callback(new_ex, stack, step_id):
new_ex.balance = orig_balance

# add to worklist
new_ex.next_pc()
new_ex.advance_pc()
stack.push(new_ex, step_id)

sub_ex = Exec(
Expand Down Expand Up @@ -2079,7 +2113,7 @@ def jumpi(
if follow_false:
new_ex_false = ex
new_ex_false.path.append(cond_false, branching=True)
new_ex_false.next_pc()
new_ex_false.advance_pc()

if new_ex_true:
if potential_true and potential_false:
Expand Down Expand Up @@ -2605,7 +2639,7 @@ def finalize(ex: Exec):
# this halts the path, but we should only halt the current context
raise HalmosException(f"Unsupported opcode {hex(opcode)}")

ex.next_pc()
ex.advance_pc()
stack.push(ex, step_id)

except InfeasiblePath as err:
Expand Down
24 changes: 11 additions & 13 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from z3 import *

from halmos.utils import EVM, hexify

from halmos.sevm import con, Contract, Instruction

from halmos.__main__ import str_abi, run_bytecode, FunctionInfo

from test_fixtures import args
Expand Down Expand Up @@ -57,10 +55,6 @@ def test_decode_concrete_bytecode():
assert contract[10] == EVM.JUMPDEST
assert contract[11] == EVM.STOP

# iteration
opcodes = [opcode for (pc, opcode) in contract]
assert bytes(opcodes).hex() == hexcode.lower()

# jump destination scanning
assert contract.valid_jump_destinations() == set([10])

Expand All @@ -83,8 +77,16 @@ def test_decode_mixed_bytecode():
assert contract[27] == EVM.RETURN
assert contract[28] == EVM.STOP # past the end

# iteration
pcs, opcodes = zip(*iter(contract))
contract.valid_jump_destinations() == set()

# force decoding
pc = 0
while pc < len(contract):
contract.decode_instruction(pc)
pc = contract.next_pc(pc)

pcs, insns = zip(*((pc, insn) for (pc, insn) in contract._insn.items()))
opcodes = tuple(insn.opcode for insn in insns)

assert opcodes == (
EVM.PUSH20,
Expand All @@ -95,7 +97,7 @@ def test_decode_mixed_bytecode():
EVM.RETURN,
)

disassembly = " ".join([str(contract.decode_instruction(pc)) for pc in pcs])
disassembly = " ".join([str(insn) for insn in insns])
assert disassembly == "PUSH20 x() PUSH0 MSTORE PUSH1 0x14 PUSH1 0x0c RETURN"

# jump destination scanning
Expand Down Expand Up @@ -134,11 +136,9 @@ def test_instruction():
def test_decode_hex():
code = Contract.from_hexcode("600100")
assert str(code.decode_instruction(0)) == f"PUSH1 {hexify(1)}"
assert [opcode for (pc, opcode) in code] == [0x60, 0x00]

code = Contract.from_hexcode("01")
assert str(code.decode_instruction(0)) == "ADD"
assert [opcode for (pc, opcode) in code] == [1]

with pytest.raises(ValueError, match="1"):
Contract.from_hexcode("1")
Expand All @@ -155,8 +155,6 @@ def test_decode():
assert str(code[31]) == "Extract(7, 0, x)"

code = Contract(Concat(BitVecVal(EVM.PUSH3, 8), BitVec("x", 16)))
ops = list(code)
assert len(ops) == 1
assert (
str(code.decode_instruction(0)) == "PUSH3 Concat(x(), 0x00)"
) # 'PUSH3 ERROR x (1 bytes missed)'
Expand Down