From f1e803b9fce0ee6f163841d198d6a0cac28a1f0e Mon Sep 17 00:00:00 2001 From: karmacoma Date: Wed, 31 Jul 2024 08:57:51 -0700 Subject: [PATCH 01/21] WIP: beginning of startPrank(address sender, address origin) --- src/halmos/cheatcodes.py | 92 +++++++++++++++++++++++++++------------- src/halmos/sevm.py | 4 +- 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index 9d08c84f..a2977fac 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -75,58 +75,68 @@ def stringified_bytes_to_bytes(hexstring: str) -> ByteVec: return ByteVec(ret_bytes) +@dataclass +class PrankResult: + sender: Address | None + origin: Address | None + + class Prank: - addr: Any # prank address + sender: Address | None # prank msg.sender address + origin: Address | None # prank tx.origin address keep: bool # start / stop prank - def __init__(self, addr: Any = None, keep: bool = False) -> None: - if addr is not None: - assert_address(addr) - self.addr = addr - self.keep = keep + def __init__(self) -> None: + self.sender = None + self.origin = None + self.keep = False + + def __bool__(self) -> bool: + return self.sender is not None or self.origin is not None def __str__(self) -> str: - if self.addr: - if self.keep: - return f"startPrank({str(self.addr)})" - else: - return f"prank({str(self.addr)})" + if not self: + return "no active prank" + + fn_name = "startPrank" if self.keep else "prank" + if self.origin is not None: + return f"{fn_name}({hexify(self.sender)}, {hexify(self.origin)})" else: - return "None" + return f"{fn_name}({hexify(self.sender)})" - def lookup(self, this: Any, to: Any) -> Any: - assert_address(this) + def lookup(self, to: Address) -> PrankResult: assert_address(to) - caller = this + result = PrankResult() if ( - self.addr is not None + self and not eq(to, hevm_cheat_code.address) and not eq(to, halmos_cheat_code.address) ): - caller = self.addr + result.caller = self.sender + result.origin = self.origin if not self.keep: - self.addr = None - return caller + self.stopPrank() + return result - def prank(self, addr: Any) -> bool: - assert_address(addr) - if self.addr is not None: + def prank(self, sender: Address) -> bool: + if self.sender is not None: return False - self.addr = addr + self.sender = sender self.keep = False return True - def startPrank(self, addr: Any) -> bool: + def startPrank(self, addr: Address) -> bool: assert_address(addr) - if self.addr is not None: + if self.sender is not None: return False - self.addr = addr + self.sender = addr self.keep = True return True def stopPrank(self) -> bool: # stopPrank is allowed to call even when no active prank exists - self.addr = None + self.sender = None + self.origin = None self.keep = False return True @@ -282,9 +292,15 @@ class hevm_cheat_code: # bytes4(keccak256("prank(address)")) prank_sig: int = 0xCA669FA7 + # bytes4(keccak256("prank(address,address)")) + prank_addr_addr_sig: int = 0x42424242 + # bytes4(keccak256("startPrank(address)")) start_prank_sig: int = 0x06447D56 + # bytes4(keccak256("startPrank(address,address)")) + start_prank_addr_addr_sig: int = 0x42424242 + # bytes4(keccak256("stopPrank()")) stop_prank_sig: int = 0x90C5013B @@ -381,8 +397,17 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: # vm.prank(address) elif funsig == hevm_cheat_code.prank_sig: - address = uint160(arg.get_word(4)) - result = ex.prank.prank(address) + sender = uint160(arg.get_word(4)) + result = ex.prank.prank(sender) + if not result: + raise HalmosException("You have an active prank already.") + return ret + + # vm.prank(address sender, address origin) + elif funsig == hevm_cheat_code.prank_addr_addr_sig: + sender = uint160(arg.get_word(4)) + origin = uint160(arg.get_word(36)) + result = ex.prank.prank(sender, origin) if not result: raise HalmosException("You have an active prank already.") return ret @@ -395,6 +420,15 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: raise HalmosException("You have an active prank already.") return ret + # vm.startPrank(address sender, address origin) + elif funsig == hevm_cheat_code.start_prank_addr_addr_sig: + sender = uint160(arg.get_word(4)) + origin = uint160(arg.get_word(36)) + result = ex.prank.startPrank(sender, origin) + if not result: + raise HalmosException("You have an active prank already.") + return ret + # vm.stopPrank() elif funsig == hevm_cheat_code.stop_prank_sig: ex.prank.stopPrank() diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 93044120..f0b16ce5 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -1558,7 +1558,9 @@ def call( if not ret_size >= 0: raise ValueError(ret_size) - caller = ex.prank.lookup(ex.this, to) + prank_result = ex.prank.lookup(to) + caller = ex.this if prank_result.sender is None else prank_result.sender + origin = f_origin() if prank_result.origin is None else prank_result.origin arg = ex.st.memory.slice(arg_loc, arg_loc + arg_size) def send_callvalue(condition=None) -> None: From 022e2df85f5c3f8315de39e1657f1464aaa6c009 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Thu, 1 Aug 2024 15:41:36 -0700 Subject: [PATCH 02/21] add tests/test_prank.py --- tests/test_prank.py | 128 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 tests/test_prank.py diff --git a/tests/test_prank.py b/tests/test_prank.py new file mode 100644 index 00000000..b735f441 --- /dev/null +++ b/tests/test_prank.py @@ -0,0 +1,128 @@ +import pytest + +from z3 import BitVec + +from halmos.cheatcodes import ( + Prank, + PrankResult, + NO_PRANK, + hevm_cheat_code, + halmos_cheat_code, +) + + +@pytest.fixture +def prank(): + return Prank() + + +@pytest.fixture +def sender(): + return BitVec("sender", 160) + + +@pytest.fixture +def origin(): + return BitVec("origin", 160) + + +@pytest.fixture +def other(): + return BitVec("other", 160) + + +def test_prank_truthiness(sender, origin): + assert not Prank() + assert Prank(PrankResult(sender=sender)) + assert Prank(PrankResult(origin=origin)) + assert Prank(PrankResult(sender=sender, origin=origin)) + + +def test_prank_can_not_override_active_prank(prank, sender, origin): + active_prank = PrankResult(sender=sender) + + # when we call prank() the first time, it activates the prank + assert prank.prank(sender) + assert prank.active == active_prank + + # when we call prank() the second time, it does not change the prank state + assert not prank.prank(sender) + assert prank.active == active_prank + + # same with startPrank + assert not prank.startPrank(sender, origin) + assert prank.active == active_prank + assert not prank.keep + + +def test_start_prank_can_not_override_active_prank(prank, sender, origin, other): + active_prank = PrankResult(sender=sender, origin=origin) + + assert prank.startPrank(sender, origin) + assert prank.active == active_prank + assert prank.keep + + # can not override active prank + assert not prank.startPrank(other, other) + assert prank.active == active_prank + assert prank.keep + + +def test_stop_prank(prank, sender, origin): + # can call stopPrank() even if there is no active prank + assert prank.stopPrank() + assert not prank.keep + assert not prank + + # when we call prank(), the prank is activated + prank.prank(sender) + assert prank.active == PrankResult(sender=sender) + assert not prank.keep + + # when we call stopPrank(), the prank is deactivated + prank.stopPrank() + assert not prank + assert not prank.keep + + # when we call startPrank(), the prank is activated + prank.startPrank(sender, origin) + assert prank.active == PrankResult(sender=sender, origin=origin) + assert prank.keep + + # when we call stopPrank(), the prank is deactivated + prank.stopPrank() + assert not prank + assert not prank.keep + + +def test_lookup_no_active_prank(prank, other): + # when we call lookup() without an active prank, it returns NO_PRANK + assert prank.lookup(other) == NO_PRANK + assert prank.lookup(hevm_cheat_code.address) == NO_PRANK + assert prank.lookup(halmos_cheat_code.address) == NO_PRANK + + +def test_prank_lookup(prank, sender, other): + # when calling lookup() after prank() + prank.prank(sender) + result = prank.lookup(other) + + # then the active prank is returned + assert result.sender == sender + assert result.origin is None + + # and the prank is no longer active + assert not prank + + +def test_startPrank_lookup(prank, sender, origin, other): + # when calling lookup() after startPrank() + prank.startPrank(sender, origin) + result = prank.lookup(other) + + # then the active prank is returned + assert result.sender == sender + assert result.origin == origin + + # and the prank is still active + assert prank From c1d5f9eb2c70106952cdfd71cd1af0dd145ddbd6 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Thu, 1 Aug 2024 15:42:03 -0700 Subject: [PATCH 03/21] WIP: rework Prank class (still need to fix the callers) --- src/halmos/cheatcodes.py | 89 +++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index a2977fac..6b5477ec 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -75,68 +75,93 @@ def stringified_bytes_to_bytes(hexstring: str) -> ByteVec: return ByteVec(ret_bytes) -@dataclass +@dataclass(frozen=True) class PrankResult: - sender: Address | None - origin: Address | None + sender: Address | None = None + origin: Address | None = None + + def __bool__(self) -> bool: + """ + True iff either sender or origin is set. + """ + return self.sender is not None or self.origin is not None + + def __str__(self) -> str: + return f"{hexify(self.sender)}, {hexify(self.origin)}" +NO_PRANK = PrankResult() + + +@dataclass class Prank: - sender: Address | None # prank msg.sender address - origin: Address | None # prank tx.origin address - keep: bool # start / stop prank + """ + A mutable object to store current prank context, one per execution context. - def __init__(self) -> None: - self.sender = None - self.origin = None - self.keep = False + Because it's mutable, it must be copied across contexts. + + Can test for the existence of an active prank with `if prank: ...` + + A prank is active if either sender or origin is set. + Technically supports pranking origin but not sender, which is not + possible with the current cheatcodes: + - prank(address) sets sender + - prank(address, address) sets both sender and origin + """ + + active: PrankResult = NO_PRANK # active prank context + keep: bool = False # start / stop prank def __bool__(self) -> bool: - return self.sender is not None or self.origin is not None + """ + True iff either sender or origin is set. + """ + return bool(self.active) def __str__(self) -> str: if not self: return "no active prank" fn_name = "startPrank" if self.keep else "prank" - if self.origin is not None: - return f"{fn_name}({hexify(self.sender)}, {hexify(self.origin)})" - else: - return f"{fn_name}({hexify(self.sender)})" + return f"{fn_name}({str(self.active)})" def lookup(self, to: Address) -> PrankResult: + """ + If `to` is an eligible prank destination, return the active prank context. + + If `keep` is False, this resets the prank context. + """ + assert_address(to) - result = PrankResult() if ( self and not eq(to, hevm_cheat_code.address) and not eq(to, halmos_cheat_code.address) ): - result.caller = self.sender - result.origin = self.origin + result = self.active if not self.keep: self.stopPrank() - return result + return result - def prank(self, sender: Address) -> bool: - if self.sender is not None: + return NO_PRANK + + def prank(self, sender: Address, origin: Address | None = None) -> bool: + assert_address(sender) + if self.active: return False - self.sender = sender + + self.active = PrankResult(sender=sender, origin=origin) self.keep = False return True - def startPrank(self, addr: Address) -> bool: - assert_address(addr) - if self.sender is not None: - return False - self.sender = addr - self.keep = True - return True + def startPrank(self, sender: Address, origin: Address | None = None) -> bool: + result = self.prank(sender, origin) + self.keep = result if result else self.keep + return result def stopPrank(self) -> bool: - # stopPrank is allowed to call even when no active prank exists - self.sender = None - self.origin = None + # stopPrank calls are allowed even when no active prank exists + self.active = NO_PRANK self.keep = False return True From 15ffcdba450f88be7d9c66b28c45c301eef3d6e5 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Thu, 1 Aug 2024 16:44:39 -0700 Subject: [PATCH 04/21] add an origin field to Exec and wire pranks to it --- src/halmos/__main__.py | 1 + src/halmos/sevm.py | 38 ++++++++++++++++++++++++++------------ src/halmos/utils.py | 4 ++++ tests/test_prank.py | 24 ++++++++++++++++++++++-- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index b7dd1ab6..19eeadf2 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -681,6 +681,7 @@ def run( jumpis={}, symbolic=args.symbolic_storage, prank=Prank(), # prank is reset after setUp() + origin=setup_ex.origin, # path=path, alias=setup_ex.alias.copy(), diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index f0b16ce5..4ba8d0f4 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -630,6 +630,7 @@ class Exec: # an execution path jumpis: Dict[str, Dict[bool, int]] # for loop detection symbolic: bool # symbolic or concrete storage prank: Prank + origin: Address addresses_to_delete: Set[Address] # path @@ -662,6 +663,7 @@ def __init__(self, **kwargs) -> None: self.jumpis = kwargs["jumpis"] self.symbolic = kwargs["symbolic"] self.prank = kwargs["prank"] + self.origin = kwargs["origin"] self.addresses_to_delete = kwargs.get("addresses_to_delete") or set() # self.path = kwargs["path"] @@ -724,6 +726,13 @@ def current_opcode(self) -> UnionType[int, BitVecRef]: def current_instruction(self) -> Instruction: return self.pgm.decode_instruction(self.pc) + def resolve_prank(self, to: Address) -> Tuple[Address, Address]: + # this potentially "consumes" the active prank + prank_result = self.prank.lookup(to) + caller = self.this if prank_result.sender is None else prank_result.sender + origin = f_origin() if prank_result.origin is None else prank_result.origin + return caller, origin + def set_code(self, who: Address, code: UnionType[ByteVec, Contract]) -> None: """ Sets the code at a given address. @@ -1558,16 +1567,14 @@ def call( if not ret_size >= 0: raise ValueError(ret_size) - prank_result = ex.prank.lookup(to) - caller = ex.this if prank_result.sender is None else prank_result.sender - origin = f_origin() if prank_result.origin is None else prank_result.origin + pranked_caller, pranked_origin = ex.resolve_prank(to) arg = ex.st.memory.slice(arg_loc, arg_loc + arg_size) def send_callvalue(condition=None) -> None: # no balance update for CALLCODE which transfers to itself if op == EVM.CALL: # TODO: revert if context is static - self.transfer_value(ex, caller, to, fund, condition) + self.transfer_value(ex, pranked_caller, to, fund, condition) def call_known(to: Address) -> None: # backup current state @@ -1580,7 +1587,7 @@ def call_known(to: Address) -> None: message = Message( target=to if op in [EVM.CALL, EVM.STATICCALL] else ex.this, - caller=caller if op != EVM.DELEGATECALL else ex.caller(), + caller=pranked_caller if op != EVM.DELEGATECALL else ex.caller(), value=fund if op != EVM.DELEGATECALL else ex.callvalue(), data=arg, is_static=(ex.context.message.is_static or op == EVM.STATICCALL), @@ -1612,6 +1619,7 @@ def callback(new_ex: Exec, stack, step_id): new_ex.jumpis = deepcopy(ex.jumpis) new_ex.symbolic = ex.symbolic new_ex.prank = deepcopy(ex.prank) + new_ex.origin = ex.origin # set return data (in memory) effective_ret_size = min(ret_size, new_ex.returndatasize()) @@ -1654,6 +1662,7 @@ def callback(new_ex: Exec, stack, step_id): jumpis={}, symbolic=ex.symbolic, prank=Prank(), + origin=pranked_origin, # path=ex.path, alias=ex.alias, @@ -1777,7 +1786,7 @@ def call_unknown() -> None: CallContext( message=Message( target=to, - caller=caller, + caller=pranked_caller, value=fund, data=ex.st.memory.slice(arg_loc, arg_loc + arg_size), call_scheme=op, @@ -1849,8 +1858,8 @@ def create( if op == EVM.CREATE2: salt = ex.st.pop() - # lookup prank - caller = ex.prank.lookup(ex.this, con_addr(0)) + # check if there is an active prank + pranked_caller, pranked_origin = ex.resolve_prank(address(ex.this)) # contract creation code create_hexcode = ex.st.memory.slice(loc, loc + size) @@ -1869,14 +1878,16 @@ def create( create_hexcode = bytes_to_bv_value(create_hexcode) code_hash = ex.sha3_data(create_hexcode) - hash_data = simplify(Concat(con(0xFF, 8), uint160(caller), salt, code_hash)) + hash_data = simplify( + Concat(con(0xFF, 8), uint160(pranked_caller), salt, code_hash) + ) new_addr = uint160(ex.sha3_data(hash_data)) else: raise HalmosException(f"Unknown CREATE opcode: {op}") message = Message( target=new_addr, - caller=caller, + caller=pranked_caller, value=value, data=create_hexcode, is_static=False, @@ -1910,7 +1921,7 @@ def create( ex.storage[new_addr] = {} # existing storage may not be empty and reset here # transfer value - self.transfer_value(ex, caller, new_addr, value) + self.transfer_value(ex, pranked_caller, new_addr, value) def callback(new_ex, stack, step_id): subcall = new_ex.context @@ -1974,6 +1985,7 @@ def callback(new_ex, stack, step_id): jumpis={}, symbolic=False, prank=Prank(), + origin=pranked_origin, # path=ex.path, alias=ex.alias, @@ -2094,6 +2106,7 @@ def create_branch(self, ex: Exec, cond: BitVecRef, target: int) -> Exec: jumpis=deepcopy(ex.jumpis), symbolic=ex.symbolic, prank=deepcopy(ex.prank), + origin=ex.origin, # path=new_path, alias=ex.alias.copy(), @@ -2293,7 +2306,7 @@ def finalize(ex: Exec): ex.st.push(uint256(ex.caller())) elif opcode == EVM.ORIGIN: - ex.st.push(uint256(f_origin())) + ex.st.push(uint256(ex.origin)) elif opcode == EVM.ADDRESS: ex.st.push(uint256(ex.this)) @@ -2628,6 +2641,7 @@ def mk_exec( jumpis={}, symbolic=symbolic, prank=Prank(), + origin=f_origin(), # path=path, alias={}, diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 6ceed36b..a1c4fe1b 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -139,6 +139,10 @@ def int256(x: BitVecRef) -> BitVecRef: return simplify(SignExt(256 - bitsize, x)) +def address(x: Any) -> Address: + return uint(x, 160) + + def con(n: int, size_bits=256) -> Word: return BitVecVal(n, BitVecSorts[size_bits]) diff --git a/tests/test_prank.py b/tests/test_prank.py index b735f441..3c92b8cc 100644 --- a/tests/test_prank.py +++ b/tests/test_prank.py @@ -103,8 +103,18 @@ def test_lookup_no_active_prank(prank, other): def test_prank_lookup(prank, sender, other): - # when calling lookup() after prank() + # setup an active prank prank.prank(sender) + + # when calling lookup(to=) + for cheat_code in [hevm_cheat_code, halmos_cheat_code]: + result = prank.lookup(cheat_code.address) + + # then the active prank is ignored + assert result == NO_PRANK + assert prank # still active + + # finally, when calling lookup(to=other) result = prank.lookup(other) # then the active prank is returned @@ -116,8 +126,18 @@ def test_prank_lookup(prank, sender, other): def test_startPrank_lookup(prank, sender, origin, other): - # when calling lookup() after startPrank() + # setup an active prank prank.startPrank(sender, origin) + + # when calling lookup(to=) + for cheat_code in [hevm_cheat_code, halmos_cheat_code]: + result = prank.lookup(cheat_code.address) + + # then the active prank is ignored + assert result == NO_PRANK + assert prank # still active + + # finally, when calling lookup(to=other) result = prank.lookup(other) # then the active prank is returned From 73f4478bc3dffb9b65cffff555c1aefc6b568c59 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Thu, 1 Aug 2024 17:21:13 -0700 Subject: [PATCH 05/21] finish wiring up new prank cheatcodes --- src/halmos/cheatcodes.py | 4 +- tests/expected/all.json | 13 ++++- tests/regression/test/Prank.t.sol | 82 ++++++++++++++++++++++++++----- 3 files changed, 82 insertions(+), 17 deletions(-) diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index 6b5477ec..c418b0f1 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -318,13 +318,13 @@ class hevm_cheat_code: prank_sig: int = 0xCA669FA7 # bytes4(keccak256("prank(address,address)")) - prank_addr_addr_sig: int = 0x42424242 + prank_addr_addr_sig: int = 0x47E50CCE # bytes4(keccak256("startPrank(address)")) start_prank_sig: int = 0x06447D56 # bytes4(keccak256("startPrank(address,address)")) - start_prank_addr_addr_sig: int = 0x42424242 + start_prank_addr_addr_sig: int = 0x45B56078 # bytes4(keccak256("stopPrank()")) stop_prank_sig: int = 0x90C5013B diff --git a/tests/expected/all.json b/tests/expected/all.json index dcf051e5..4f77220b 100644 --- a/tests/expected/all.json +++ b/tests/expected/all.json @@ -1337,7 +1337,7 @@ ], "test/Prank.t.sol:PrankTest": [ { - "name": "check_prank(address)", + "name": "check_prank(address,address)", "exitcode": 0, "num_models": 0, "models": null, @@ -1354,6 +1354,15 @@ "time": null, "num_bounded_loops": null }, + { + "name": "check_prank_ConstructorCreate2(address,bytes32)", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, { "name": "check_prank_External(address)", "exitcode": 0, @@ -1409,7 +1418,7 @@ "num_bounded_loops": null }, { - "name": "check_startPrank(address)", + "name": "check_startPrank(address,address)", "exitcode": 0, "num_models": 0, "models": null, diff --git a/tests/regression/test/Prank.t.sol b/tests/regression/test/Prank.t.sol index 1da70d52..d5f65adf 100644 --- a/tests/regression/test/Prank.t.sol +++ b/tests/regression/test/Prank.t.sol @@ -7,13 +7,16 @@ contract Dummy { } contract Target { address public caller; + address public origin; - function setCaller(address addr) public { - caller = addr; + function setCaller(address _caller, address _origin) public { + caller = _caller; + origin = _origin; } function recordCaller() public { caller = msg.sender; + origin = tx.origin; } } @@ -60,31 +63,85 @@ contract PrankTest is Test { vm.prank(user); } - function check_prank(address user) public { + function check_prank(address user, address origin) public { vm.prank(user); + + // check that the prank is active target.recordCaller(); assert(target.caller() == user); + assert(target.origin() == tx.origin); // not pranked + // check that the prank is no longer active target.recordCaller(); assert(target.caller() == address(this)); + assert(target.origin() == tx.origin); + + //////////////////////////// + // check alternative form // + //////////////////////////// + + vm.prank(user, origin); + + // check that the prank is active + target.recordCaller(); + assert(target.caller() == user); + assert(target.origin() == origin); + + // check that the prank is no longer active + target.recordCaller(); + assert(target.caller() == address(this)); + assert(target.origin() == tx.origin); + } - function check_startPrank(address user) public { + function check_startPrank(address user, address origin) public { vm.startPrank(user); target.recordCaller(); assert(target.caller() == user); + assert(target.origin() == tx.origin); // not pranked - target.setCaller(address(this)); + target.setCaller(address(this), address(this)); assert(target.caller() == address(this)); + assert(target.origin() == address(this)); + // prank is still active until stopPrank() is called target.recordCaller(); assert(target.caller() == user); + assert(target.origin() == tx.origin); // not pranked vm.stopPrank(); + // prank is no longer active target.recordCaller(); assert(target.caller() == address(this)); + assert(target.origin() == tx.origin); + + //////////////////////////// + // check alternative form // + //////////////////////////// + + vm.startPrank(user, origin); + + target.recordCaller(); + assert(target.caller() == user); + assert(target.origin() == origin); + + target.setCaller(address(this), address(this)); + assert(target.caller() == address(this)); + assert(target.origin() == address(this)); + + // prank is still active until stopPrank() is called + target.recordCaller(); + assert(target.caller() == user); + assert(target.origin() == origin); + + vm.stopPrank(); + + // prank is no longer active + target.recordCaller(); + assert(target.caller() == address(this)); + assert(target.origin() == tx.origin); } function check_prank_Internal(address user) public { @@ -94,13 +151,13 @@ contract PrankTest is Test { } function check_prank_External(address user) public { - ext.prank(user); // prank isn't propagated beyond the vm boundry + ext.prank(user); // prank isn't propagated beyond the vm boundary target.recordCaller(); assert(target.caller() == address(this)); } function check_prank_ExternalSelf(address user) public { - this.prank(user); // prank isn't propagated beyond the vm boundry + this.prank(user); // prank isn't propagated beyond the vm boundary target.recordCaller(); assert(target.caller() == address(this)); } @@ -146,10 +203,9 @@ contract PrankTest is Test { assert(recorder.caller() == user); } - // TODO: uncomment when we add CREATE2 support - // function check_prank_ConstructorCreate2(address user, bytes32 salt) public { - // vm.prank(user); - // ConstructorRecorder recorder = new ConstructorRecorder{salt:salt}(); - // assert(recorder.caller() == user); - // } + function check_prank_ConstructorCreate2(address user, bytes32 salt) public { + vm.prank(user); + ConstructorRecorder recorder = new ConstructorRecorder{salt:salt}(); + assert(recorder.caller() == user); + } } From 1c3373bc7dd58bf1f9a1587f338fa1d3f6d71399 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 08:56:31 -0700 Subject: [PATCH 06/21] simplify prank Target: recordCaller() -> reset() --- tests/regression/test/Prank.t.sol | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/regression/test/Prank.t.sol b/tests/regression/test/Prank.t.sol index d5f65adf..7eac8d99 100644 --- a/tests/regression/test/Prank.t.sol +++ b/tests/regression/test/Prank.t.sol @@ -9,9 +9,9 @@ contract Target { address public caller; address public origin; - function setCaller(address _caller, address _origin) public { - caller = _caller; - origin = _origin; + function reset() public { + caller = address(0); + origin = address(0); } function recordCaller() public { @@ -101,9 +101,9 @@ contract PrankTest is Test { assert(target.caller() == user); assert(target.origin() == tx.origin); // not pranked - target.setCaller(address(this), address(this)); - assert(target.caller() == address(this)); - assert(target.origin() == address(this)); + target.reset(); + assert(target.caller() == address(0)); + assert(target.origin() == address(0)); // prank is still active until stopPrank() is called target.recordCaller(); @@ -127,9 +127,9 @@ contract PrankTest is Test { assert(target.caller() == user); assert(target.origin() == origin); - target.setCaller(address(this), address(this)); - assert(target.caller() == address(this)); - assert(target.origin() == address(this)); + target.reset(); + assert(target.caller() == address(0)); + assert(target.origin() == address(0)); // prank is still active until stopPrank() is called target.recordCaller(); From 9939ca573b21d6bfd226408aae9e67cf6679e61f Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 10:21:16 -0700 Subject: [PATCH 07/21] update tests/lib/multicaller to v1.3.2 (it removes the extra dependency on forge-std) --- tests/lib/multicaller | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lib/multicaller b/tests/lib/multicaller index b4a0dd03..b7ef6206 160000 --- a/tests/lib/multicaller +++ b/tests/lib/multicaller @@ -1 +1 @@ -Subproject commit b4a0dd037f1d770b2e9ae0b80bbd989707df43d0 +Subproject commit b7ef620605f426a93406248dcb005f6cead30673 From 3ec4901371484922d79b3e55d5698c91890e705f Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 10:26:46 -0700 Subject: [PATCH 08/21] remove shallow from .gitmodules --- .gitmodules | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.gitmodules b/.gitmodules index e5346aeb..881921c8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,24 +1,18 @@ [submodule "tests/lib/forge-std"] path = tests/lib/forge-std url = https://github.com/foundry-rs/forge-std - shallow = true [submodule "tests/lib/halmos-cheatcodes"] path = tests/lib/halmos-cheatcodes url = https://github.com/a16z/halmos-cheatcodes - shallow = true [submodule "tests/lib/openzeppelin-contracts"] path = tests/lib/openzeppelin-contracts url = https://github.com/OpenZeppelin/openzeppelin-contracts - shallow = true [submodule "tests/lib/solmate"] path = tests/lib/solmate url = https://github.com/transmissions11/solmate - shallow = true [submodule "tests/lib/solady"] path = tests/lib/solady url = https://github.com/Vectorized/solady - shallow = true [submodule "tests/lib/multicaller"] path = tests/lib/multicaller url = https://github.com/Vectorized/multicaller - shallow = true From 324f77d02b18ac94372b0c7e8cbddae861b7e119 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 10:27:20 -0700 Subject: [PATCH 09/21] delete multicaller --- tests/lib/multicaller | 1 - 1 file changed, 1 deletion(-) delete mode 160000 tests/lib/multicaller diff --git a/tests/lib/multicaller b/tests/lib/multicaller deleted file mode 160000 index b7ef6206..00000000 --- a/tests/lib/multicaller +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b7ef620605f426a93406248dcb005f6cead30673 From b6c1e4869d1692f3eb4b72aedc6cf1d057720f3e Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 10:34:38 -0700 Subject: [PATCH 10/21] replace multicaller submodule with a snapshot of the file we need --- .gitmodules | 3 - examples/simple/remappings.txt | 2 +- .../src/multicaller/MulticallerWithSender.sol | 151 ++++++++++++++++++ 3 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 examples/simple/src/multicaller/MulticallerWithSender.sol diff --git a/.gitmodules b/.gitmodules index 881921c8..2ecec2e0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,6 +13,3 @@ [submodule "tests/lib/solady"] path = tests/lib/solady url = https://github.com/Vectorized/solady -[submodule "tests/lib/multicaller"] - path = tests/lib/multicaller - url = https://github.com/Vectorized/multicaller diff --git a/examples/simple/remappings.txt b/examples/simple/remappings.txt index 37cf11ec..887332d9 100644 --- a/examples/simple/remappings.txt +++ b/examples/simple/remappings.txt @@ -1 +1 @@ -multicaller/=../../tests/lib/multicaller/src/ +multicaller/=src/multicaller/ diff --git a/examples/simple/src/multicaller/MulticallerWithSender.sol b/examples/simple/src/multicaller/MulticallerWithSender.sol new file mode 100644 index 00000000..47240298 --- /dev/null +++ b/examples/simple/src/multicaller/MulticallerWithSender.sol @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.4; + +/// from Vectorized/multicaller@v1.3.2 + +/** + * @title MulticallerWithSender + * @author vectorized.eth + * @notice Contract that allows for efficient aggregation of multiple calls + * in a single transaction, while "forwarding" the `msg.sender`. + */ +contract MulticallerWithSender { + // ============================================================= + // ERRORS + // ============================================================= + + /** + * @dev The lengths of the input arrays are not the same. + */ + error ArrayLengthsMismatch(); + + /** + * @dev This function does not support reentrancy. + */ + error Reentrancy(); + + // ============================================================= + // CONSTRUCTOR + // ============================================================= + + constructor() payable { + assembly { + // Throughout this code, we will abuse returndatasize + // in place of zero anywhere before a call to save a bit of gas. + // We will use storage slot zero to store the caller at + // bits [0..159] and reentrancy guard flag at bit 160. + sstore(returndatasize(), shl(160, 1)) + } + } + + // ============================================================= + // AGGREGATION OPERATIONS + // ============================================================= + + /** + * @dev Returns the address that called `aggregateWithSender` on this contract. + * The value is always the zero address outside a transaction. + */ + receive() external payable { + assembly { + mstore(returndatasize(), and(sub(shl(160, 1), 1), sload(returndatasize()))) + return(returndatasize(), 0x20) + } + } + + /** + * @dev Aggregates multiple calls in a single transaction. + * This method will set `sender` to the `msg.sender` temporarily + * for the span of its execution. + * This method does not support reentrancy. + * @param targets An array of addresses to call. + * @param data An array of calldata to forward to the targets. + * @param values How much ETH to forward to each target. + * @return An array of the returndata from each call. + */ + function aggregateWithSender( + address[] calldata targets, + bytes[] calldata data, + uint256[] calldata values + ) external payable returns (bytes[] memory) { + assembly { + if iszero(and(eq(targets.length, data.length), eq(data.length, values.length))) { + // Store the function selector of `ArrayLengthsMismatch()`. + mstore(returndatasize(), 0x3b800a46) + // Revert with (offset, size). + revert(0x1c, 0x04) + } + + if iszero(and(sload(returndatasize()), shl(160, 1))) { + // Store the function selector of `Reentrancy()`. + mstore(returndatasize(), 0xab143c06) + // Revert with (offset, size). + revert(0x1c, 0x04) + } + + mstore(returndatasize(), 0x20) // Store the memory offset of the `results`. + mstore(0x20, data.length) // Store `data.length` into `results`. + // Early return if no data. + if iszero(data.length) { return(returndatasize(), 0x40) } + + // Set the sender slot temporarily for the span of this transaction. + sstore(returndatasize(), caller()) + + let results := 0x40 + // Left shift by 5 is equivalent to multiplying by 0x20. + data.length := shl(5, data.length) + // Copy the offsets from calldata into memory. + calldatacopy(results, data.offset, data.length) + // Offset into `results`. + let resultsOffset := data.length + // Pointer to the end of `results`. + // Recycle `data.length` to avoid stack too deep. + data.length := add(results, data.length) + + for {} 1 {} { + // The offset of the current bytes in the calldata. + let o := add(data.offset, mload(results)) + let memPtr := add(resultsOffset, 0x40) + // Copy the current bytes from calldata to the memory. + calldatacopy( + memPtr, + add(o, 0x20), // The offset of the current bytes' bytes. + calldataload(o) // The length of the current bytes. + ) + if iszero( + call( + gas(), // Remaining gas. + calldataload(targets.offset), // Address to call. + calldataload(values.offset), // ETH to send. + memPtr, // Start of input calldata in memory. + calldataload(o), // Size of input calldata. + 0x00, // We will use returndatacopy instead. + 0x00 // We will use returndatacopy instead. + ) + ) { + // Bubble up the revert if the call reverts. + returndatacopy(0x00, 0x00, returndatasize()) + revert(0x00, returndatasize()) + } + // Advance the `targets.offset`. + targets.offset := add(targets.offset, 0x20) + // Advance the `values.offset`. + values.offset := add(values.offset, 0x20) + // Append the current `resultsOffset` into `results`. + mstore(results, resultsOffset) + results := add(results, 0x20) + // Append the returndatasize, and the returndata. + mstore(memPtr, returndatasize()) + returndatacopy(add(memPtr, 0x20), 0x00, returndatasize()) + // Advance the `resultsOffset` by `returndatasize() + 0x20`, + // rounded up to the next multiple of 0x20. + resultsOffset := and(add(add(resultsOffset, returndatasize()), 0x3f), not(0x1f)) + if iszero(lt(results, data.length)) { break } + } + // Restore the `sender` slot. + sstore(0, shl(160, 1)) + // Direct return. + return(0x00, add(resultsOffset, 0x40)) + } + } +} From c051be4d4120381866665ef9843834bbd1c3de4b Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 10:50:53 -0700 Subject: [PATCH 11/21] update tests/lib/openzeppelin-contracts@v5.0.2 --- tests/lib/openzeppelin-contracts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lib/openzeppelin-contracts b/tests/lib/openzeppelin-contracts index 21bb89ef..dbb6104c 160000 --- a/tests/lib/openzeppelin-contracts +++ b/tests/lib/openzeppelin-contracts @@ -1 +1 @@ -Subproject commit 21bb89ef5bfc789b9333eb05e3ba2b7b284ac77c +Subproject commit dbb6104ce834628e473d2173bbc9d47f81a9eec3 From 8bb0d6fb5b8f9c3e133ccfbd169123cb0383ad6c Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 10:54:35 -0700 Subject: [PATCH 12/21] test.yml: recursively checkout submodules The foundry step is failing with weird errors: Submodule 'tests/lib/forge-std' (https://github.com/foundry-rs/forge-std) registered for path '../lib/forge-std' Submodule 'tests/lib/halmos-cheatcodes' (https://github.com/a16z/halmos-cheatcodes) registered for path '../lib/halmos-cheatcodes' Submodule 'tests/lib/openzeppelin-contracts' (https://github.com/OpenZeppelin/openzeppelin-contracts) registered for path '../lib/openzeppelin-contracts' Submodule 'tests/lib/solady' (https://github.com/Vectorized/solady) registered for path '../lib/solady' Submodule 'tests/lib/solmate' (https://github.com/transmissions11/solmate) registered for path '../lib/solmate' fatal: not a git repository: /home/runner/work/halmos/halmos/tests/lib/forge-std/../../../.git/modules/tests/lib/forge-std Failed to clone 'tests/lib/forge-std'. Retry scheduled fatal: destination path '/home/runner/work/halmos/halmos/tests/lib/halmos-cheatcodes' already exists and is not an empty directory. fatal: clone of 'https://github.com/a16z/halmos-cheatcodes' into submodule path '/home/runner/work/halmos/halmos/tests/lib/halmos-cheatcodes' failed Failed to clone 'tests/lib/halmos-cheatcodes'. Retry scheduled --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 769fa670..c7ba2e45 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,7 +34,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 with: - submodules: false + submodules: recursive - name: Install foundry uses: foundry-rs/foundry-toolchain@v1 From 707b9e1c02e4f0142681bcdd8bd135b7e34430b9 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 11:01:41 -0700 Subject: [PATCH 13/21] test.yml: add --debug to halmos options --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c7ba2e45..75345e9e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,4 +53,4 @@ jobs: run: python -m pip install -e . - name: Run pytest - run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="-st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }} + run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="--debug -st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }} From 656c359c838acaa0d7cf4222c9691475445c52e2 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 11:04:21 -0700 Subject: [PATCH 14/21] test.yml: get back to a single pytest worker Suspecting race conditions here: ``` Build failed: ['forge', 'build', '--ast', '--root', 'tests/regression', '--extra-output', 'storageLayout', 'metadata'] ----------------------------- Captured stderr call ----------------------------- Failed to install solc 0.8.26: Text file busy (os error 26) Error: Text file busy (os error 26) ``` --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 75345e9e..f7166f15 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,4 +53,4 @@ jobs: run: python -m pip install -e . - name: Run pytest - run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="--debug -st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }} + run: pytest -n 1 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="--debug -st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }} From de64d95dc2b45ea429d5d05a2d94408ec62d7320 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 11:10:01 -0700 Subject: [PATCH 15/21] Revert "test.yml: recursively checkout submodules" This reverts commit 8bb0d6fb5b8f9c3e133ccfbd169123cb0383ad6c. --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f7166f15..336aece9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,7 +34,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 with: - submodules: recursive + submodules: false - name: Install foundry uses: foundry-rs/foundry-toolchain@v1 From 1689032b344abc4c065c8ea252ec10f692b35c49 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 17:37:13 -0700 Subject: [PATCH 16/21] add Prank test with nested contexts --- tests/regression/test/Prank.t.sol | 42 ++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/tests/regression/test/Prank.t.sol b/tests/regression/test/Prank.t.sol index 7eac8d99..66181881 100644 --- a/tests/regression/test/Prank.t.sol +++ b/tests/regression/test/Prank.t.sol @@ -6,9 +6,14 @@ import "forge-std/Test.sol"; contract Dummy { } contract Target { + Target public inner; address public caller; address public origin; + function setInnerTarget(Target _inner) public { + inner = _inner; + } + function reset() public { caller = address(0); origin = address(0); @@ -17,6 +22,10 @@ contract Target { function recordCaller() public { caller = msg.sender; origin = tx.origin; + + if (address(inner) != address(0)) { + inner.recordCaller(); + } } } @@ -51,11 +60,14 @@ contract PrankSetUpTest is Test { contract PrankTest is Test { Target target; + Target inner; Ext ext; Dummy dummy; function setUp() public { target = new Target(); + inner = new Target(); + target.setInnerTarget(inner); ext = new Ext(); } @@ -63,35 +75,41 @@ contract PrankTest is Test { vm.prank(user); } - function check_prank(address user, address origin) public { + function checkNotPranked(Target _target, address realCaller) internal { + assert(_target.caller() == realCaller); + assert(_target.origin() == tx.origin); + } + + function check_prank(address user) public { vm.prank(user); - // check that the prank is active + // the outer call is pranked target.recordCaller(); assert(target.caller() == user); assert(target.origin() == tx.origin); // not pranked + // but the inner call is not pranked + checkNotPranked(inner, address(target)); + // check that the prank is no longer active target.recordCaller(); - assert(target.caller() == address(this)); - assert(target.origin() == tx.origin); - - //////////////////////////// - // check alternative form // - //////////////////////////// + checkNotPranked(target, address(this)); + } + function check_prank(address user, address origin) public { vm.prank(user, origin); - // check that the prank is active + // the outer call is pranked target.recordCaller(); assert(target.caller() == user); assert(target.origin() == origin); + // but the inner call is not pranked + checkNotPranked(inner, address(target)); + // check that the prank is no longer active target.recordCaller(); - assert(target.caller() == address(this)); - assert(target.origin() == tx.origin); - + checkNotPranked(target, address(this)); } function check_startPrank(address user, address origin) public { From 09b806656b164bc58043f4c674f873d24e2066ea Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 2 Aug 2024 18:32:27 -0700 Subject: [PATCH 17/21] add more Prank tests with nested contexts --- tests/regression/test/Prank.t.sol | 41 +++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/tests/regression/test/Prank.t.sol b/tests/regression/test/Prank.t.sol index 66181881..6dca3ff5 100644 --- a/tests/regression/test/Prank.t.sol +++ b/tests/regression/test/Prank.t.sol @@ -17,6 +17,10 @@ contract Target { function reset() public { caller = address(0); origin = address(0); + + if (address(inner) != address(0)) { + inner.reset(); + } } function recordCaller() public { @@ -76,11 +80,11 @@ contract PrankTest is Test { } function checkNotPranked(Target _target, address realCaller) internal { - assert(_target.caller() == realCaller); - assert(_target.origin() == tx.origin); + assertEq(_target.caller(), realCaller); + assertEq(_target.origin(), tx.origin); } - function check_prank(address user) public { + function check_prank_single(address user) public { vm.prank(user); // the outer call is pranked @@ -96,7 +100,7 @@ contract PrankTest is Test { checkNotPranked(target, address(this)); } - function check_prank(address user, address origin) public { + function check_prank_double(address user, address origin) public { vm.prank(user, origin); // the outer call is pranked @@ -112,54 +116,65 @@ contract PrankTest is Test { checkNotPranked(target, address(this)); } - function check_startPrank(address user, address origin) public { + function check_startPrank_single(address user) public { vm.startPrank(user); + // the outer call is pranked target.recordCaller(); assert(target.caller() == user); assert(target.origin() == tx.origin); // not pranked + // the inner call is not pranked + checkNotPranked(inner, address(target)); + target.reset(); assert(target.caller() == address(0)); assert(target.origin() == address(0)); + assert(inner.caller() == address(0)); + assert(inner.origin() == address(0)); // prank is still active until stopPrank() is called target.recordCaller(); assert(target.caller() == user); assert(target.origin() == tx.origin); // not pranked + checkNotPranked(inner, address(target)); vm.stopPrank(); // prank is no longer active target.recordCaller(); - assert(target.caller() == address(this)); - assert(target.origin() == tx.origin); - - //////////////////////////// - // check alternative form // - //////////////////////////// + checkNotPranked(target, address(this)); + checkNotPranked(inner, address(target)); + } + function check_startPrank_double(address user, address origin) public { vm.startPrank(user, origin); target.recordCaller(); assert(target.caller() == user); assert(target.origin() == origin); + assert(inner.caller() == address(target)); // not pranked + assert(inner.origin() == origin); // pranked target.reset(); assert(target.caller() == address(0)); assert(target.origin() == address(0)); + assert(inner.caller() == address(0)); + assert(inner.origin() == address(0)); // prank is still active until stopPrank() is called target.recordCaller(); assert(target.caller() == user); assert(target.origin() == origin); + assert(inner.caller() == address(target)); // not pranked + assert(inner.origin() == origin); // pranked vm.stopPrank(); // prank is no longer active target.recordCaller(); - assert(target.caller() == address(this)); - assert(target.origin() == tx.origin); + checkNotPranked(target, address(this)); + checkNotPranked(inner, address(target)); } function check_prank_Internal(address user) public { From df2c7a3e77afa3712f10d1d5e6d89e5f802b8fdf Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 9 Aug 2024 15:12:39 -0700 Subject: [PATCH 18/21] wire Prank inside CallContext rather than Exec --- src/halmos/__main__.py | 35 ++++++--------- src/halmos/cheatcodes.py | 15 +++---- src/halmos/sevm.py | 73 ++++++++++++------------------- tests/expected/all.json | 38 +++++++++++----- tests/regression/test/Prank.t.sol | 5 ++- tests/test_sevm.py | 8 ++-- 6 files changed, 81 insertions(+), 93 deletions(-) diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 59c959c6..28d0e645 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -188,10 +188,7 @@ def mk_addr(name: str) -> Address: def mk_caller(args: HalmosConfig) -> Address: - if args.symbolic_msg_sender: - return mk_addr("msg_sender") - else: - return magic_address + return mk_addr("msg_sender") if args.symbolic_msg_sender else magic_address def mk_this() -> Address: @@ -336,27 +333,24 @@ def render_trace(context: CallContext, file=sys.stdout) -> None: def run_bytecode(hexcode: str, args: HalmosConfig) -> List[Exec]: solver = mk_solver(args) - contract = Contract.from_hexcode(hexcode) - balance = mk_balance() - block = mk_block() this = mk_this() - message = Message( target=this, caller=mk_caller(args), + origin=mk_addr("tx_origin"), value=mk_callvalue(), data=ByteVec(), call_scheme=EVM.CALL, ) + contract = Contract.from_hexcode(hexcode) sevm = SEVM(args) ex = sevm.mk_exec( code={this: contract}, storage={this: {}}, - balance=balance, - block=block, + balance=mk_balance(), + block=mk_block(), context=CallContext(message=message), - this=this, pgm=contract, symbolic=args.symbolic_storage, path=Path(solver), @@ -366,7 +360,6 @@ def run_bytecode(hexcode: str, args: HalmosConfig) -> List[Exec]: for idx, ex in enumerate(exs): result_exs.append(ex) - opcode = ex.current_opcode() error = ex.context.output.error returndata = ex.context.output.data @@ -403,6 +396,7 @@ def deploy_test( message = Message( target=this, caller=mk_caller(args), + origin=mk_addr("tx_origin"), value=0, data=ByteVec(), call_scheme=EVM.CREATE, @@ -414,7 +408,6 @@ def deploy_test( balance=mk_balance(), block=mk_block(), context=CallContext(message=message), - this=this, pgm=None, # to be added symbolic=False, path=Path(mk_solver(args)), @@ -463,7 +456,6 @@ def deploy_test( ex.st = State() ex.context.output = CallOutput() ex.jumpis = {} - ex.prank = Prank() return ex @@ -492,10 +484,12 @@ def setup( dyn_param_size = [] # TODO: propagate to run mk_calldata(abi, setup_info, calldata, dyn_param_size, args) + parent_message = setup_ex.message() setup_ex.context = CallContext( message=Message( - target=setup_ex.message().target, - caller=setup_ex.message().caller, + target=parent_message.target, + caller=parent_message.caller, + origin=parent_message.origin, value=0, data=calldata, call_scheme=EVM.CALL, @@ -503,7 +497,6 @@ def setup( ) setup_exs_all = sevm.run(setup_ex) - setup_exs_no_error = [] for idx, setup_ex in enumerate(setup_exs_all): @@ -640,8 +633,9 @@ def run( mk_calldata(abi, fun_info, cd, dyn_param_size, args) message = Message( - target=setup_ex.this, + target=setup_ex.this(), caller=setup_ex.caller(), + origin=setup_ex.origin(), value=0, data=cd, call_scheme=EVM.CALL, @@ -669,15 +663,12 @@ def run( # context=CallContext(message=message), callback=None, - this=setup_ex.this, # - pgm=setup_ex.code[setup_ex.this], + pgm=setup_ex.code[setup_ex.this()], pc=0, st=State(), jumpis={}, symbolic=args.symbolic_storage, - prank=Prank(), # prank is reset after setUp() - origin=setup_ex.origin, # path=path, alias=setup_ex.alias.copy(), diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index c418b0f1..8c48be1a 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -96,15 +96,14 @@ def __str__(self) -> str: @dataclass class Prank: """ - A mutable object to store current prank context, one per execution context. + A mutable object to store the current prank context. Because it's mutable, it must be copied across contexts. Can test for the existence of an active prank with `if prank: ...` A prank is active if either sender or origin is set. - Technically supports pranking origin but not sender, which is not - possible with the current cheatcodes: + - prank(address) sets sender - prank(address, address) sets both sender and origin """ @@ -423,7 +422,7 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: # vm.prank(address) elif funsig == hevm_cheat_code.prank_sig: sender = uint160(arg.get_word(4)) - result = ex.prank.prank(sender) + result = ex.context.prank.prank(sender) if not result: raise HalmosException("You have an active prank already.") return ret @@ -432,7 +431,7 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: elif funsig == hevm_cheat_code.prank_addr_addr_sig: sender = uint160(arg.get_word(4)) origin = uint160(arg.get_word(36)) - result = ex.prank.prank(sender, origin) + result = ex.context.prank.prank(sender, origin) if not result: raise HalmosException("You have an active prank already.") return ret @@ -440,7 +439,7 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: # vm.startPrank(address) elif funsig == hevm_cheat_code.start_prank_sig: address = uint160(arg.get_word(4)) - result = ex.prank.startPrank(address) + result = ex.context.prank.startPrank(address) if not result: raise HalmosException("You have an active prank already.") return ret @@ -449,14 +448,14 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: elif funsig == hevm_cheat_code.start_prank_addr_addr_sig: sender = uint160(arg.get_word(4)) origin = uint160(arg.get_word(36)) - result = ex.prank.startPrank(sender, origin) + result = ex.context.prank.startPrank(sender, origin) if not result: raise HalmosException("You have an active prank already.") return ret # vm.stopPrank() elif funsig == hevm_cheat_code.stop_prank_sig: - ex.prank.stopPrank() + ex.context.prank.stopPrank() return ret # vm.deal(address,uint256) diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 4ba8d0f4..ec7f13e9 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -59,8 +59,6 @@ f_gas = Function("f_gas", BitVecSort256, BitVecSort256) # gasprice() f_gasprice = Function("f_gasprice", BitVecSort256) -# origin() -f_origin = Function("f_origin", BitVecSort160) # uninterpreted arithmetic f_div = Function("f_evm_bvudiv", BitVecSort256, BitVecSort256, BitVecSort256) @@ -192,6 +190,7 @@ class EventLog: class Message: target: Address caller: Address + origin: Address value: Word data: ByteVec @@ -232,6 +231,7 @@ class CallContext: output: CallOutput = field(default_factory=CallOutput) depth: int = 1 trace: List[TraceElement] = field(default_factory=list) + prank: Prank = field(default_factory=Prank) def subcalls(self) -> Iterator["CallContext"]: return iter(t for t in self.trace if isinstance(t, CallContext)) @@ -623,14 +623,11 @@ class Exec: # an execution path callback: Optional[Callable] # to be called when returning back to parent context # vm state - this: Address # current account address pgm: Contract pc: int st: State # stack and memory jumpis: Dict[str, Dict[bool, int]] # for loop detection symbolic: bool # symbolic or concrete storage - prank: Prank - origin: Address addresses_to_delete: Set[Address] # path @@ -656,14 +653,11 @@ def __init__(self, **kwargs) -> None: self.context = kwargs["context"] self.callback = kwargs["callback"] # - self.this = kwargs["this"] self.pgm = kwargs["pgm"] self.pc = kwargs["pc"] self.st = kwargs["st"] self.jumpis = kwargs["jumpis"] self.symbolic = kwargs["symbolic"] - self.prank = kwargs["prank"] - self.origin = kwargs["origin"] self.addresses_to_delete = kwargs.get("addresses_to_delete") or set() # self.path = kwargs["path"] @@ -677,13 +671,13 @@ def __init__(self, **kwargs) -> None: self.known_keys = kwargs["known_keys"] if "known_keys" in kwargs else {} self.known_sigs = kwargs["known_sigs"] if "known_sigs" in kwargs else {} - assert_address(self.context.message.target) - assert_address(self.context.message.caller) - assert_address(self.this) + assert_address(self.origin()) + assert_address(self.caller()) + assert_address(self.this()) def context_str(self) -> str: opcode = self.current_opcode() - return f"addr={hexify(self.this)} pc={self.pc} insn={mnemonic(opcode)}" + return f"addr={hexify(self.this())} pc={self.pc} insn={mnemonic(opcode)}" def halt( self, @@ -714,9 +708,15 @@ def calldata(self) -> ByteVec: def caller(self): return self.message().caller + def origin(self): + return self.message().origin + def callvalue(self): return self.message().value + def this(self): + return self.message().target + def message(self): return self.context.message @@ -728,9 +728,9 @@ def current_instruction(self) -> Instruction: def resolve_prank(self, to: Address) -> Tuple[Address, Address]: # this potentially "consumes" the active prank - prank_result = self.prank.lookup(to) - caller = self.this if prank_result.sender is None else prank_result.sender - origin = f_origin() if prank_result.origin is None else prank_result.origin + prank_result = self.context.prank.lookup(to) + caller = self.this() if prank_result.sender is None else prank_result.sender + origin = self.origin() if prank_result.origin is None else prank_result.origin return caller, origin def set_code(self, who: Address, code: UnionType[ByteVec, Contract]) -> None: @@ -748,7 +748,7 @@ def dump(self, print_mem=False) -> str: return hexify( "".join( [ - f"PC: {self.this} {self.pc} {mnemonic(self.current_opcode())}\n", + f"PC: {self.this()} {self.pc} {mnemonic(self.current_opcode())}\n", self.st.dump(print_mem=print_mem), f"Balance: {self.balance}\n", f"Storage:\n", @@ -1586,8 +1586,9 @@ def call_known(to: Address) -> None: send_callvalue() message = Message( - target=to if op in [EVM.CALL, EVM.STATICCALL] else ex.this, + target=to if op in [EVM.CALL, EVM.STATICCALL] else ex.this(), caller=pranked_caller if op != EVM.DELEGATECALL else ex.caller(), + origin=pranked_origin, value=fund if op != EVM.DELEGATECALL else ex.callvalue(), data=arg, is_static=(ex.context.message.is_static or op == EVM.STATICCALL), @@ -1602,8 +1603,6 @@ def callback(new_ex: Exec, stack, step_id): # restore context new_ex.context = deepcopy(ex.context) new_ex.context.trace.append(subcall) - new_ex.this = ex.this - new_ex.callback = ex.callback if subcall.is_stuck(): @@ -1618,8 +1617,6 @@ def callback(new_ex: Exec, stack, step_id): new_ex.st = deepcopy(ex.st) new_ex.jumpis = deepcopy(ex.jumpis) new_ex.symbolic = ex.symbolic - new_ex.prank = deepcopy(ex.prank) - new_ex.origin = ex.origin # set return data (in memory) effective_ret_size = min(ret_size, new_ex.returndatasize()) @@ -1654,15 +1651,12 @@ def callback(new_ex: Exec, stack, step_id): # context=CallContext(message=message, depth=ex.context.depth + 1), callback=callback, - this=message.target, # pgm=ex.code[to], pc=0, st=State(), jumpis={}, symbolic=ex.symbolic, - prank=Prank(), - origin=pranked_origin, # path=ex.path, alias=ex.alias, @@ -1787,6 +1781,7 @@ def call_unknown() -> None: message=Message( target=to, caller=pranked_caller, + origin=pranked_origin, value=fund, data=ex.st.memory.slice(arg_loc, arg_loc + arg_size), call_scheme=op, @@ -1859,7 +1854,7 @@ def create( salt = ex.st.pop() # check if there is an active prank - pranked_caller, pranked_origin = ex.resolve_prank(address(ex.this)) + pranked_caller, pranked_origin = ex.resolve_prank(con_addr(0)) # contract creation code create_hexcode = ex.st.memory.slice(loc, loc + size) @@ -1888,6 +1883,7 @@ def create( message = Message( target=new_addr, caller=pranked_caller, + origin=pranked_origin, value=value, data=create_hexcode, is_static=False, @@ -1930,18 +1926,14 @@ def callback(new_ex, stack, step_id): # pessimistic copy because the subcall results may diverge new_ex.context = deepcopy(ex.context) new_ex.context.trace.append(subcall) - new_ex.callback = ex.callback - new_ex.this = ex.this - # restore vm state new_ex.pgm = ex.pgm new_ex.pc = ex.pc new_ex.st = deepcopy(ex.st) new_ex.jumpis = deepcopy(ex.jumpis) new_ex.symbolic = ex.symbolic - new_ex.prank = deepcopy(ex.prank) if subcall.is_stuck(): # internal errors abort the current path, @@ -1977,15 +1969,12 @@ def callback(new_ex, stack, step_id): # context=CallContext(message=message, depth=ex.context.depth + 1), callback=callback, - this=new_addr, # pgm=create_code, pc=0, st=State(), jumpis={}, symbolic=False, - prank=Prank(), - origin=pranked_origin, # path=ex.path, alias=ex.alias, @@ -2098,15 +2087,12 @@ def create_branch(self, ex: Exec, cond: BitVecRef, target: int) -> Exec: # context=deepcopy(ex.context), callback=ex.callback, - this=ex.this, # pgm=ex.pgm, pc=target, st=deepcopy(ex.st), jumpis=deepcopy(ex.jumpis), symbolic=ex.symbolic, - prank=deepcopy(ex.prank), - origin=ex.origin, # path=new_path, alias=ex.alias.copy(), @@ -2306,10 +2292,10 @@ def finalize(ex: Exec): ex.st.push(uint256(ex.caller())) elif opcode == EVM.ORIGIN: - ex.st.push(uint256(ex.origin)) + ex.st.push(uint256(ex.origin())) elif opcode == EVM.ADDRESS: - ex.st.push(uint256(ex.this)) + ex.st.push(uint256(ex.this())) # TODO: define f_extcodesize for known addresses in advance elif opcode == EVM.EXTCODESIZE: @@ -2410,7 +2396,7 @@ def finalize(ex: Exec): ex.st.push(ex.balance_of(uint160(ex.st.pop()))) elif opcode == EVM.SELFBALANCE: - ex.st.push(ex.balance_of(ex.this)) + ex.st.push(ex.balance_of(ex.this())) elif opcode in [ EVM.CALL, @@ -2453,12 +2439,12 @@ def finalize(ex: Exec): elif opcode == EVM.SLOAD: slot: Word = ex.st.pop() - ex.st.push(self.sload(ex, ex.this, slot)) + ex.st.push(self.sload(ex, ex.this(), slot)) elif opcode == EVM.SSTORE: slot: Word = ex.st.pop() value: Word = ex.st.pop() - self.sstore(ex, ex.this, slot, value) + self.sstore(ex, ex.this(), slot, value) elif opcode == EVM.RETURNDATASIZE: ex.st.push(ex.returndatasize()) @@ -2552,7 +2538,7 @@ def finalize(ex: Exec): size: int = int_of(ex.st.pop(), "symbolic LOG data size") topics = list(ex.st.pop() for _ in range(num_topics)) data = ex.st.memory.slice(loc, loc + size) - ex.emit_log(EventLog(ex.this, topics, data)) + ex.emit_log(EventLog(ex.this(), topics, data)) elif opcode == EVM.PUSH0: ex.st.push(con(0)) @@ -2618,8 +2604,6 @@ def mk_exec( # context: CallContext, # - this, - # pgm, symbolic, path, @@ -2634,14 +2618,11 @@ def mk_exec( context=context, callback=None, # top-level; no callback # - this=this, pgm=pgm, pc=0, st=State(), jumpis={}, symbolic=symbolic, - prank=Prank(), - origin=f_origin(), # path=path, alias={}, diff --git a/tests/expected/all.json b/tests/expected/all.json index 4f77220b..65424da8 100644 --- a/tests/expected/all.json +++ b/tests/expected/all.json @@ -1336,15 +1336,6 @@ } ], "test/Prank.t.sol:PrankTest": [ - { - "name": "check_prank(address,address)", - "exitcode": 0, - "num_models": 0, - "models": null, - "num_paths": null, - "time": null, - "num_bounded_loops": null - }, { "name": "check_prank_Constructor(address)", "exitcode": 0, @@ -1418,7 +1409,34 @@ "num_bounded_loops": null }, { - "name": "check_startPrank(address,address)", + "name": "check_prank_double(address,address)", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_prank_single(address)", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_startPrank_double(address,address)", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, + { + "name": "check_startPrank_single(address)", "exitcode": 0, "num_models": 0, "models": null, diff --git a/tests/regression/test/Prank.t.sol b/tests/regression/test/Prank.t.sol index 6dca3ff5..3abe41d4 100644 --- a/tests/regression/test/Prank.t.sol +++ b/tests/regression/test/Prank.t.sol @@ -108,8 +108,9 @@ contract PrankTest is Test { assert(target.caller() == user); assert(target.origin() == origin); - // but the inner call is not pranked - checkNotPranked(inner, address(target)); + // the inner call also sees the pranked origin + assert(inner.caller() == address(target)); + assert(target.origin() == origin); // check that the prank is no longer active target.recordCaller(); diff --git a/tests/test_sevm.py b/tests/test_sevm.py index 2aa17dac..bde9c664 100644 --- a/tests/test_sevm.py +++ b/tests/test_sevm.py @@ -14,7 +14,6 @@ f_mod, f_smod, f_exp, - f_origin, CallContext, Message, SEVM, @@ -30,11 +29,10 @@ from test_fixtures import args, sevm, solver caller = BitVec("msg_sender", 160) - +origin = BitVec("tx_origin", 160) this = BitVec("this_address", 160) balance = Array("balance_0", BitVecSort(160), BitVecSort(256)) - callvalue = BitVec("msg_value", 256) @@ -49,6 +47,7 @@ def mk_ex(hexcode, sevm, solver, storage, caller, this): message = Message( target=this, caller=caller, + origin=origin, value=callvalue, data=ByteVec(), call_scheme=EVM.CALL, @@ -60,7 +59,6 @@ def mk_ex(hexcode, sevm, solver, storage, caller, this): balance=balance, block=mk_block(), context=CallContext(message), - this=this, pgm=bytecode, symbolic=True, path=Path(solver), @@ -272,7 +270,7 @@ def byte_of(i, x): # TODO: SHA3 (o(EVM.ADDRESS), [], uint256(this)), (o(EVM.BALANCE), [x], Select(balance, uint160(x))), - (o(EVM.ORIGIN), [], uint256(f_origin())), + (o(EVM.ORIGIN), [], uint256(origin)), (o(EVM.CALLER), [], uint256(caller)), (o(EVM.CALLVALUE), [], callvalue), # TODO: CALLDATA*, CODE*, EXTCODE*, RETURNDATA*, CREATE* From 1ea690c057726d6b3e984a6985e7945632fd8954 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 9 Aug 2024 16:06:28 -0700 Subject: [PATCH 19/21] add a startPrank in constructor test --- tests/expected/all.json | 13 +++++++-- tests/regression/test/Prank.t.sol | 47 ++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/tests/expected/all.json b/tests/expected/all.json index 65424da8..0774ace3 100644 --- a/tests/expected/all.json +++ b/tests/expected/all.json @@ -1337,7 +1337,7 @@ ], "test/Prank.t.sol:PrankTest": [ { - "name": "check_prank_Constructor(address)", + "name": "check_prank_Constructor(address,address)", "exitcode": 0, "num_models": 0, "models": null, @@ -1346,7 +1346,7 @@ "num_bounded_loops": null }, { - "name": "check_prank_ConstructorCreate2(address,bytes32)", + "name": "check_prank_ConstructorCreate2(address,address,bytes32)", "exitcode": 0, "num_models": 0, "models": null, @@ -1426,6 +1426,15 @@ "time": null, "num_bounded_loops": null }, + { + "name": "check_prank_startPrank_in_constructor(address,address)", + "exitcode": 0, + "num_models": 0, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + }, { "name": "check_startPrank_double(address,address)", "exitcode": 0, diff --git a/tests/regression/test/Prank.t.sol b/tests/regression/test/Prank.t.sol index 3abe41d4..130a52f9 100644 --- a/tests/regression/test/Prank.t.sol +++ b/tests/regression/test/Prank.t.sol @@ -35,9 +35,17 @@ contract Target { contract ConstructorRecorder { address public caller; + address public origin; constructor() { caller = msg.sender; + origin = tx.origin; + } +} + +contract PrankyConstructor is TestBase { + constructor(address user, address origin) { + vm.startPrank(user, origin); } } @@ -231,15 +239,46 @@ contract PrankTest is Test { assert(target.caller() == address(this)); } - function check_prank_Constructor(address user) public { - vm.prank(user); + function check_prank_Constructor(address user, address origin) public { + address senderBefore = msg.sender; + address originBefore = tx.origin; + + vm.prank(user, origin); ConstructorRecorder recorder = new ConstructorRecorder(); assert(recorder.caller() == user); + assert(recorder.origin() == origin); + + // origin and sender are restored + assertEq(msg.sender, senderBefore); + assertEq(tx.origin, originBefore); } - function check_prank_ConstructorCreate2(address user, bytes32 salt) public { - vm.prank(user); + function check_prank_ConstructorCreate2(address user, address origin, bytes32 salt) public { + address senderBefore = msg.sender; + address originBefore = tx.origin; + + vm.prank(user, origin); ConstructorRecorder recorder = new ConstructorRecorder{salt:salt}(); assert(recorder.caller() == user); + assert(recorder.origin() == origin); + + // origin and sender are restored + assertEq(msg.sender, senderBefore); + assertEq(tx.origin, originBefore); + } + + function check_prank_startPrank_in_constructor(address user, address origin) public { + address senderBefore = msg.sender; + address originBefore = tx.origin; + + PrankyConstructor pranky = new PrankyConstructor(user, origin); + + // results are not affected by the startPrank in the constructor + assertEq(msg.sender, senderBefore); + assertEq(tx.origin, originBefore); + + target.recordCaller(); + assert(target.caller() == address(this)); + assert(target.origin() == originBefore); } } From c1c64e68382bb5f72f44b5a4de7a03e2d229c158 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Fri, 9 Aug 2024 16:34:28 -0700 Subject: [PATCH 20/21] add test_prank_in_context --- tests/test_prank.py | 76 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/test_prank.py b/tests/test_prank.py index 3c92b8cc..62d1198b 100644 --- a/tests/test_prank.py +++ b/tests/test_prank.py @@ -10,6 +10,8 @@ halmos_cheat_code, ) +from halmos.sevm import Message, CallContext + @pytest.fixture def prank(): @@ -146,3 +148,77 @@ def test_startPrank_lookup(prank, sender, origin, other): # and the prank is still active assert prank + + +def test_prank_in_context(sender, origin): + """ + This is part test and part documentation. + + It implements the intended handling of messages, contexts and pranks by sevm, + and it shows the expected flow of values from prank creation to consumption. + """ + + pranked_sender = BitVec("pranked_sender", 160) + pranked_origin = BitVec("pranked_origin", 160) + CALL = 0xF1 + + # start with a basic context + context = CallContext( + message=Message( + target=BitVec("original_target", 160), + caller=sender, + origin=origin, + value=0, + data=b"", + call_scheme=CALL, + ) + ) + + assert not context.prank + + # a call to vm.prank() would mutate the context's active prank + context.prank.prank(pranked_sender, pranked_origin) + + # the context now has an active prank + assert context.prank + + # when creating a sub-context (e.g. for a new call), the prank should be consumed + call1_target = BitVec("call1_target", 160) + call1_prank_result = context.prank.lookup(call1_target) + sub_context1 = CallContext( + message=Message( + target=call1_target, + caller=call1_prank_result.sender, + origin=call1_prank_result.origin, + value=0, + data=b"", + call_scheme=CALL, + ) + ) + + assert not context.prank + assert sub_context1.message.caller == pranked_sender + assert sub_context1.message.origin == pranked_origin + + # the sub-context should have no active prank + assert not sub_context1.prank + + # subcalls do inherit the origin from the parent context + call2_target = BitVec("call2_target", 160) + assert not context.prank.lookup(call2_target) + sub_context2 = CallContext( + message=Message( + target=call2_target, + caller=sub_context1.message.target, + origin=sub_context1.message.origin, + value=0, + data=b"", + call_scheme=CALL, + ), + ) + + assert not sub_context2.prank + assert sub_context2.message.caller == sub_context1.message.target # real + assert ( + sub_context2.message.origin == sub_context1.message.origin + ) # pranked (indirectly) From 17eba86b1f3806979c81c536fb193c16bda9f226 Mon Sep 17 00:00:00 2001 From: karmacoma Date: Tue, 13 Aug 2024 15:21:39 -0700 Subject: [PATCH 21/21] less convoluted code in prank/startPrank --- src/halmos/cheatcodes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index 8c48be1a..faca777e 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -144,19 +144,19 @@ def lookup(self, to: Address) -> PrankResult: return NO_PRANK - def prank(self, sender: Address, origin: Address | None = None) -> bool: + def prank( + self, sender: Address, origin: Address | None = None, _keep: bool = False + ) -> bool: assert_address(sender) if self.active: return False self.active = PrankResult(sender=sender, origin=origin) - self.keep = False + self.keep = _keep return True def startPrank(self, sender: Address, origin: Address | None = None) -> bool: - result = self.prank(sender, origin) - self.keep = result if result else self.keep - return result + return self.prank(sender, origin, _keep=True) def stopPrank(self) -> bool: # stopPrank calls are allowed even when no active prank exists