From 2e858354dcb3647fedb6033721724a32c7efd78a Mon Sep 17 00:00:00 2001 From: Reuben Rodrigues Date: Mon, 3 Apr 2023 18:32:28 +0530 Subject: [PATCH] feat: add count to `expectCall` cheatcode --- evm/src/executor/abi/mod.rs | 4 + .../executor/inspector/cheatcodes/expect.rs | 124 ++++++++++++--- evm/src/executor/inspector/cheatcodes/mod.rs | 53 ++++--- forge/README.md | 14 +- testdata/cheats/Cheats.sol | 13 ++ testdata/cheats/ExpectCall.t.sol | 144 ++++++++++++++++++ 6 files changed, 300 insertions(+), 52 deletions(-) diff --git a/evm/src/executor/abi/mod.rs b/evm/src/executor/abi/mod.rs index efda7d8014c1..4e5dd12acf21 100644 --- a/evm/src/executor/abi/mod.rs +++ b/evm/src/executor/abi/mod.rs @@ -102,9 +102,13 @@ abigen!( clearMockedCalls() expectCall(address,bytes) + expectCall(address,bytes,uint64) expectCall(address,uint256,bytes) + expectCall(address,uint256,bytes,uint64) expectCall(address,uint256,uint64,bytes) + expectCall(address,uint256,uint64,bytes,uint64) expectCallMinGas(address,uint256,uint64,bytes) + expectCallMinGas(address,uint256,uint64,bytes,uint64) expectSafeMemory(uint64,uint64) expectSafeMemoryCall(uint64,uint64) diff --git a/evm/src/executor/inspector/cheatcodes/expect.rs b/evm/src/executor/inspector/cheatcodes/expect.rs index 5cc0ebe07c32..7347c2d4ab8c 100644 --- a/evm/src/executor/inspector/cheatcodes/expect.rs +++ b/evm/src/executor/inspector/cheatcodes/expect.rs @@ -210,6 +210,8 @@ pub struct ExpectedCallData { pub gas: Option, /// The expected *minimum* gas supplied to the call pub min_gas: Option, + /// The number of times the call is expected to be made + pub count: u64, } #[derive(Clone, Debug, Default, PartialEq, Eq)] @@ -293,51 +295,123 @@ pub fn apply( Ok(Bytes::new()) } HEVMCalls::ExpectCall0(inner) => { - state.expected_calls.entry(inner.0).or_default().push(ExpectedCallData { - calldata: inner.1.to_vec().into(), - value: None, - gas: None, - min_gas: None, - }); + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.1.to_vec().into(), + value: None, + gas: None, + min_gas: None, + count: 1, + }, + 0, + )); Ok(Bytes::new()) } HEVMCalls::ExpectCall1(inner) => { - state.expected_calls.entry(inner.0).or_default().push(ExpectedCallData { - calldata: inner.2.to_vec().into(), - value: Some(inner.1), - gas: None, - min_gas: None, - }); + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.1.to_vec().into(), + value: None, + gas: None, + min_gas: None, + count: inner.2, + }, + 0, + )); Ok(Bytes::new()) } HEVMCalls::ExpectCall2(inner) => { + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.2.to_vec().into(), + value: Some(inner.1), + gas: None, + min_gas: None, + count: 1, + }, + 0, + )); + Ok(Bytes::new()) + } + HEVMCalls::ExpectCall3(inner) => { + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.2.to_vec().into(), + value: Some(inner.1), + gas: None, + min_gas: None, + count: inner.3, + }, + 0, + )); + Ok(Bytes::new()) + } + HEVMCalls::ExpectCall4(inner) => { let value = inner.1; // If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas // to ensure that the basic fallback function can be called. let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; - state.expected_calls.entry(inner.0).or_default().push(ExpectedCallData { - calldata: inner.3.to_vec().into(), - value: Some(value), - gas: Some(inner.2 + positive_value_cost_stipend), - min_gas: None, - }); + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.3.to_vec().into(), + value: Some(value), + gas: Some(inner.2 + positive_value_cost_stipend), + min_gas: None, + count: 1, + }, + 0, + )); + Ok(Bytes::new()) + } + HEVMCalls::ExpectCall5(inner) => { + let value = inner.1; + let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.3.to_vec().into(), + value: Some(value), + gas: Some(inner.2 + positive_value_cost_stipend), + min_gas: None, + count: inner.4, + }, + 0, + )); Ok(Bytes::new()) } - HEVMCalls::ExpectCallMinGas(inner) => { + HEVMCalls::ExpectCallMinGas0(inner) => { let value = inner.1; // If the value of the transaction is non-zero, the EVM adds a call stipend of 2300 gas // to ensure that the basic fallback function can be called. let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; - state.expected_calls.entry(inner.0).or_default().push(ExpectedCallData { - calldata: inner.3.to_vec().into(), - value: Some(value), - gas: None, - min_gas: Some(inner.2 + positive_value_cost_stipend), - }); + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.3.to_vec().into(), + value: Some(value), + gas: None, + min_gas: Some(inner.2 + positive_value_cost_stipend), + count: 1, + }, + 0, + )); + Ok(Bytes::new()) + } + HEVMCalls::ExpectCallMinGas1(inner) => { + let value = inner.1; + let positive_value_cost_stipend = if value > U256::zero() { 2300 } else { 0 }; + state.expected_calls.entry(inner.0).or_default().push(( + ExpectedCallData { + calldata: inner.3.to_vec().into(), + value: Some(value), + gas: None, + min_gas: Some(inner.2 + positive_value_cost_stipend), + count: inner.4, + }, + 0, + )); Ok(Bytes::new()) } HEVMCalls::MockCall0(inner) => { diff --git a/evm/src/executor/inspector/cheatcodes/mod.rs b/evm/src/executor/inspector/cheatcodes/mod.rs index 7d1d7d3158d0..77ff7b588ab5 100644 --- a/evm/src/executor/inspector/cheatcodes/mod.rs +++ b/evm/src/executor/inspector/cheatcodes/mod.rs @@ -121,7 +121,7 @@ pub struct Cheatcodes { pub mocked_calls: BTreeMap>, /// Expected calls - pub expected_calls: BTreeMap>, + pub expected_calls: BTreeMap>, /// Expected emits pub expected_emits: Vec, @@ -542,14 +542,14 @@ where } else if call.contract != HARDHAT_CONSOLE_ADDRESS { // Handle expected calls if let Some(expecteds) = self.expected_calls.get_mut(&call.contract) { - if let Some(found_match) = expecteds.iter().position(|expected| { + if let Some((_, count)) = expecteds.iter_mut().find(|(expected, _)| { expected.calldata.len() <= call.input.len() && expected.calldata == call.input[..expected.calldata.len()] && expected.value.map_or(true, |value| value == call.transfer.value) && expected.gas.map_or(true, |gas| gas == call.gas_limit) && expected.min_gas.map_or(true, |min_gas| min_gas <= call.gas_limit) }) { - expecteds.remove(found_match); + *count += 1; } } @@ -738,28 +738,31 @@ where // If the depth is 0, then this is the root call terminating if data.journaled_state.depth() == 0 { - // Handle expected calls that were not fulfilled - if let Some((address, expecteds)) = - self.expected_calls.iter().find(|(_, expecteds)| !expecteds.is_empty()) - { - let ExpectedCallData { calldata, gas, min_gas, value } = &expecteds[0]; - let calldata = ethers::types::Bytes::from(calldata.clone()); - let expected_values = [ - Some(format!("data {calldata}")), - value.map(|v| format!("value {v}")), - gas.map(|g| format!("gas {g}")), - min_gas.map(|g| format!("minimum gas {g}")), - ] - .into_iter() - .flatten() - .join(" and "); - return ( - Return::Revert, - remaining_gas, - format!("Expected a call to {address:?} with {expected_values}, but got none") - .encode() - .into(), - ) + for (address, expecteds) in &self.expected_calls { + for (expected, actual_count) in expecteds { + let ExpectedCallData { calldata, gas, min_gas, value, count } = expected; + let calldata = ethers::types::Bytes::from(calldata.clone()); + if *count != *actual_count { + let expected_values = [ + Some(format!("data {calldata}")), + value.map(|v| format!("value {v}")), + gas.map(|g| format!("gas {g}")), + min_gas.map(|g| format!("minimum gas {g}")), + ] + .into_iter() + .flatten() + .join(" and "); + return ( + Return::Revert, + remaining_gas, + format!( + "Expected call to {address:?} with {expected_values} to be called {count} time(s), but was called {actual_count} time(s)" + ) + .encode() + .into(), + ) + } + } } // Check if we have any leftover expected emits diff --git a/forge/README.md b/forge/README.md index eb2f068800f1..35cb196ced43 100644 --- a/forge/README.md +++ b/forge/README.md @@ -318,13 +318,23 @@ interface Hevm { function clearMockedCalls() external; // Expect a call to an address with the specified calldata. // Calldata can either be strict or a partial match - function expectCall(address,bytes calldata) external; + function expectCall(address, bytes calldata) external; + // Expect given number of calls to an address with the specified calldata. + // Calldata can either be strict or a partial match + function expectCall(address, bytes calldata, uint64) external; // Expect a call to an address with the specified msg.value and calldata - function expectCall(address,uint256,bytes calldata) external; + function expectCall(address, uint256, bytes calldata) external; + // Expect a given number of calls to an address with the specified msg.value and calldata + function expectCall(address, uint256, bytes calldata, uint64) external; // Expect a call to an address with the specified msg.value, gas, and calldata. function expectCall(address, uint256, uint64, bytes calldata) external; + // Expect a given number of calls to an address with the specified msg.value, gas, and calldata. + function expectCall(address, uint256, uint64, bytes calldata, uint64) external; // Expect a call to an address with the specified msg.value and calldata, and a *minimum* amount of gas. function expectCallMinGas(address, uint256, uint64, bytes calldata) external; + // Expect a given number of calls to an address with the specified msg.value and calldata, and a *minimum* amount of gas. + function expectCallMinGas(address, uint256, uint64, bytes calldata, uint64) external; + // Only allows memory writes to offsets [0x00, 0x60) ∪ [min, max) in the current subcontext. If any other // memory is written to, the test will fail. function expectSafeMemory(uint64, uint64) external; diff --git a/testdata/cheats/Cheats.sol b/testdata/cheats/Cheats.sol index 198378403658..7b40e8e3fb34 100644 --- a/testdata/cheats/Cheats.sol +++ b/testdata/cheats/Cheats.sol @@ -208,15 +208,28 @@ interface Cheats { // Calldata can either be strict or a partial match function expectCall(address, bytes calldata) external; + // Expect given number of calls to an address with the specified calldata. + // Calldata can either be strict or a partial match + function expectCall(address, bytes calldata, uint64) external; + // Expect a call to an address with the specified msg.value and calldata function expectCall(address, uint256, bytes calldata) external; + // Expect a given number of calls to an address with the specified msg.value and calldata + function expectCall(address, uint256, bytes calldata, uint64) external; + // Expect a call to an address with the specified msg.value, gas, and calldata. function expectCall(address, uint256, uint64, bytes calldata) external; + // Expect a given number of calls to an address with the specified msg.value, gas, and calldata. + function expectCall(address, uint256, uint64, bytes calldata, uint64) external; + // Expect a call to an address with the specified msg.value and calldata, and a *minimum* amount of gas. function expectCallMinGas(address, uint256, uint64, bytes calldata) external; + // Expect a given number of calls to an address with the specified msg.value and calldata, and a *minimum* amount of gas. + function expectCallMinGas(address, uint256, uint64, bytes calldata, uint64) external; + // Only allows memory writes to offsets [0x00, 0x60) ∪ [min, max) in the current subcontext. If any other // memory is written to, the test will fail. function expectSafeMemory(uint64, uint64) external; diff --git a/testdata/cheats/ExpectCall.t.sol b/testdata/cheats/ExpectCall.t.sol index b9330b8db979..ffa3c1168139 100644 --- a/testdata/cheats/ExpectCall.t.sol +++ b/testdata/cheats/ExpectCall.t.sol @@ -161,3 +161,147 @@ contract ExpectCallTest is DSTest { target.addHardGasLimit(); } } + +contract ExpectCallCountTest is DSTest { + Cheats constant cheats = Cheats(HEVM_ADDRESS); + + function testExpectCallCountWithData() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 3); + target.add(1, 2); + target.add(1, 2); + target.add(1, 2); + } + + function testExpectZeroCallCountAssert() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 0); + target.add(3, 3); + } + + function testFailExpectCallCountWithWrongCount() public { + Contract target = new Contract(); + cheats.expectCall(address(target), abi.encodeWithSelector(target.add.selector, 1, 2), 2); + target.add(1, 2); + } + + function testExpectCountInnerCall() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), abi.encodeWithSelector(inner.numberB.selector), 1); + target.sum(); + } + + function testFailExpectCountInnerCall() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), abi.encodeWithSelector(inner.numberB.selector), 1); + + // this function does not call inner + target.hello(); + } + + function testExpectCountInnerAndOuterCalls() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), abi.encodeWithSelector(inner.numberB.selector), 2); + inner.numberB(); + target.sum(); + } + + function testExpectCallCountWithValue() public { + Contract target = new Contract(); + cheats.expectCall(address(target), 1, abi.encodeWithSelector(target.pay.selector, 2), 1); + target.pay{value: 1}(2); + } + + function testExpectZeroCallCountValue() public { + Contract target = new Contract(); + cheats.expectCall(address(target), 1, abi.encodeWithSelector(target.pay.selector, 2), 0); + target.pay{value: 2}(2); + } + + function testFailExpectCallCountValue() public { + Contract target = new Contract(); + cheats.expectCall(address(target), 1, abi.encodeWithSelector(target.pay.selector, 2), 1); + target.pay{value: 2}(2); + } + + function testExpectCallCountWithValueWithoutParameters() public { + Contract target = new Contract(); + cheats.expectCall(address(target), 3, abi.encodeWithSelector(target.pay.selector), 3); + target.pay{value: 3}(100); + target.pay{value: 3}(100); + target.pay{value: 3}(100); + } + + function testExpectCallCountWithValueAndGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), 1, 50_000, abi.encodeWithSelector(inner.pay.selector, 1), 2); + target.forwardPay{value: 1}(); + target.forwardPay{value: 1}(); + } + + function testExpectCallCountWithNoValueAndGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), 0, 50_000, abi.encodeWithSelector(inner.add.selector, 1, 1), 1); + target.addHardGasLimit(); + } + + function testExpectZeroCallCountWithNoValueAndWrongGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), 0, 25_000, abi.encodeWithSelector(inner.add.selector, 1, 1), 0); + target.addHardGasLimit(); + } + + function testFailExpectCallCountWithNoValueAndWrongGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCall(address(inner), 0, 25_000, abi.encodeWithSelector(inner.add.selector, 1, 1), 2); + target.addHardGasLimit(); + target.addHardGasLimit(); + } + + function testExpectCallCountWithValueAndMinGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCallMinGas(address(inner), 1, 50_000, abi.encodeWithSelector(inner.pay.selector, 1), 1); + target.forwardPay{value: 1}(); + } + + function testExpectCallCountWithNoValueAndMinGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCallMinGas(address(inner), 0, 25_000, abi.encodeWithSelector(inner.add.selector, 1, 1), 2); + target.addHardGasLimit(); + target.addHardGasLimit(); + } + + function testExpectCallZeroCountWithNoValueAndWrongMinGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCallMinGas(address(inner), 0, 50_001, abi.encodeWithSelector(inner.add.selector, 1, 1), 0); + target.addHardGasLimit(); + } + + function testFailExpectCallCountWithNoValueAndWrongMinGas() public { + Contract inner = new Contract(); + NestedContract target = new NestedContract(inner); + + cheats.expectCallMinGas(address(inner), 0, 50_001, abi.encodeWithSelector(inner.add.selector, 1, 1), 1); + target.addHardGasLimit(); + } +}