diff --git a/crates/config/src/invariant.rs b/crates/config/src/invariant.rs index 17994c3e3b30..0eb96bdee2a2 100644 --- a/crates/config/src/invariant.rs +++ b/crates/config/src/invariant.rs @@ -32,6 +32,9 @@ pub struct InvariantConfig { /// Useful for handlers that use cheatcodes as roll or warp /// Use it with caution, introduces performance penalty. pub preserve_state: bool, + /// The maximum number of rejects via `vm.assume` which can be encountered during a single + /// invariant run. + pub max_assume_rejects: u32, } impl Default for InvariantConfig { @@ -45,6 +48,7 @@ impl Default for InvariantConfig { shrink_sequence: true, shrink_run_limit: 2usize.pow(18_u32), preserve_state: false, + max_assume_rejects: 65536, } } } diff --git a/crates/evm/evm/src/executors/invariant/error.rs b/crates/evm/evm/src/executors/invariant/error.rs index 50dedab7c81d..8acc67f6a783 100644 --- a/crates/evm/evm/src/executors/invariant/error.rs +++ b/crates/evm/evm/src/executors/invariant/error.rs @@ -53,7 +53,27 @@ pub struct InvariantFuzzTestResult { } #[derive(Clone, Debug)] -pub struct InvariantFuzzError { +pub enum InvariantFuzzError { + Revert(FailedInvariantCaseData), + BrokenInvariant(FailedInvariantCaseData), + MaxAssumeRejects(u32), +} + +impl InvariantFuzzError { + pub fn revert_reason(&self) -> Option { + match self { + Self::BrokenInvariant(case_data) | Self::Revert(case_data) => { + (!case_data.revert_reason.is_empty()).then(|| case_data.revert_reason.clone()) + } + Self::MaxAssumeRejects(allowed) => Some(format!( + "The `vm.assume` cheatcode rejected too many inputs ({allowed} allowed)" + )), + } + } +} + +#[derive(Clone, Debug)] +pub struct FailedInvariantCaseData { pub logs: Vec, pub traces: Option, /// The proptest error occurred as a result of a test case. @@ -74,7 +94,7 @@ pub struct InvariantFuzzError { pub shrink_run_limit: usize, } -impl InvariantFuzzError { +impl FailedInvariantCaseData { pub fn new( invariant_contract: &InvariantContract<'_>, error_func: Option<&Function>, @@ -93,7 +113,7 @@ impl InvariantFuzzError { .with_abi(invariant_contract.abi) .decode(call_result.result.as_ref(), Some(call_result.exit_reason)); - InvariantFuzzError { + Self { logs: call_result.logs, traces: call_result.traces, test_error: proptest::test_runner::TestError::Fail( diff --git a/crates/evm/evm/src/executors/invariant/funcs.rs b/crates/evm/evm/src/executors/invariant/funcs.rs index 810abb259d0c..237b4dad84ec 100644 --- a/crates/evm/evm/src/executors/invariant/funcs.rs +++ b/crates/evm/evm/src/executors/invariant/funcs.rs @@ -1,4 +1,4 @@ -use super::{InvariantFailures, InvariantFuzzError}; +use super::{error::FailedInvariantCaseData, InvariantFailures, InvariantFuzzError}; use crate::executors::{Executor, RawCallResult}; use alloy_dyn_abi::JsonAbiExt; use alloy_json_abi::Function; @@ -50,7 +50,7 @@ pub fn assert_invariants( if is_err { // We only care about invariants which we haven't broken yet. if invariant_failures.error.is_none() { - invariant_failures.error = Some(InvariantFuzzError::new( + let case_data = FailedInvariantCaseData::new( invariant_contract, Some(func), calldata, @@ -58,7 +58,8 @@ pub fn assert_invariants( &inner_sequence, shrink_sequence, shrink_run_limit, - )); + ); + invariant_failures.error = Some(InvariantFuzzError::BrokenInvariant(case_data)); return None } } diff --git a/crates/evm/evm/src/executors/invariant/mod.rs b/crates/evm/evm/src/executors/invariant/mod.rs index 7b87c0f06600..2ec627b11e36 100644 --- a/crates/evm/evm/src/executors/invariant/mod.rs +++ b/crates/evm/evm/src/executors/invariant/mod.rs @@ -9,7 +9,7 @@ use eyre::{eyre, ContextCompat, Result}; use foundry_common::contracts::{ContractsByAddress, ContractsByArtifact}; use foundry_config::{FuzzDictionaryConfig, InvariantConfig}; use foundry_evm_core::{ - constants::{CALLER, CHEATCODE_ADDRESS, HARDHAT_CONSOLE_ADDRESS}, + constants::{CALLER, CHEATCODE_ADDRESS, HARDHAT_CONSOLE_ADDRESS, MAGIC_ASSUME}, utils::{get_function, StateChangeset}, }; use foundry_evm_fuzz::{ @@ -38,11 +38,13 @@ use foundry_evm_fuzz::strategies::CalldataFuzzDictionary; mod funcs; pub use funcs::{assert_invariants, replay_run}; +use self::error::FailedInvariantCaseData; + /// Alias for (Dictionary for fuzzing, initial contracts to fuzz and an InvariantStrategy). type InvariantPreparation = ( EvmFuzzState, FuzzRunIdentifiedContracts, - BoxedStrategy>, + BoxedStrategy, CalldataFuzzDictionary, ); @@ -143,7 +145,9 @@ impl<'a> InvariantExecutor<'a> { // during the run. We need another proptest runner to query for random // values. let branch_runner = RefCell::new(self.runner.clone()); - let _ = self.runner.run(&strat, |mut inputs| { + let _ = self.runner.run(&strat, |first_input| { + let mut inputs = vec![first_input]; + // We stop the run immediately if we have reverted, and `fail_on_revert` is set. if self.config.fail_on_revert && failures.borrow().reverts > 0 { return Err(TestCaseError::fail("Revert occurred.")) @@ -158,7 +162,10 @@ impl<'a> InvariantExecutor<'a> { // Created contracts during a run. let mut created_contracts = vec![]; - for current_run in 0..self.config.depth { + let mut current_run = 0; + let mut assume_rejects_counter = 0; + + while current_run < self.config.depth { let (sender, (address, calldata)) = inputs.last().expect("no input generated"); // Executes the call from the randomly generated sequence. @@ -172,65 +179,77 @@ impl<'a> InvariantExecutor<'a> { .expect("could not make raw evm call") }; - // Collect data for fuzzing from the state changeset. - let mut state_changeset = - call_result.state_changeset.to_owned().expect("no changesets"); - - collect_data( - &mut state_changeset, - sender, - &call_result, - fuzz_state.clone(), - &self.config.dictionary, - ); + if call_result.result.as_ref() == MAGIC_ASSUME { + inputs.pop(); + assume_rejects_counter += 1; + if assume_rejects_counter > self.config.max_assume_rejects { + failures.borrow_mut().error = Some(InvariantFuzzError::MaxAssumeRejects( + self.config.max_assume_rejects, + )); + return Err(TestCaseError::fail("Max number of vm.assume rejects reached.")) + } + } else { + // Collect data for fuzzing from the state changeset. + let mut state_changeset = + call_result.state_changeset.to_owned().expect("no changesets"); + + collect_data( + &mut state_changeset, + sender, + &call_result, + fuzz_state.clone(), + &self.config.dictionary, + ); - if let Err(error) = collect_created_contracts( - &state_changeset, - self.project_contracts, - self.setup_contracts, - &self.artifact_filters, - targeted_contracts.clone(), - &mut created_contracts, - ) { - warn!(target: "forge::test", "{error}"); - } + if let Err(error) = collect_created_contracts( + &state_changeset, + self.project_contracts, + self.setup_contracts, + &self.artifact_filters, + targeted_contracts.clone(), + &mut created_contracts, + ) { + warn!(target: "forge::test", "{error}"); + } - // Commit changes to the database. - executor.backend.commit(state_changeset.clone()); - - fuzz_runs.push(FuzzCase { - calldata: calldata.clone(), - gas: call_result.gas_used, - stipend: call_result.stipend, - }); - - let RichInvariantResults { success: can_continue, call_result: call_results } = - can_continue( - &invariant_contract, - call_result, - &executor, - &inputs, - &mut failures.borrow_mut(), - &targeted_contracts, - state_changeset, - self.config.fail_on_revert, - self.config.shrink_sequence, - self.config.shrink_run_limit, - ); + // Commit changes to the database. + executor.backend.commit(state_changeset.clone()); + + fuzz_runs.push(FuzzCase { + calldata: calldata.clone(), + gas: call_result.gas_used, + stipend: call_result.stipend, + }); + + let RichInvariantResults { success: can_continue, call_result: call_results } = + can_continue( + &invariant_contract, + call_result, + &executor, + &inputs, + &mut failures.borrow_mut(), + &targeted_contracts, + state_changeset, + self.config.fail_on_revert, + self.config.shrink_sequence, + self.config.shrink_run_limit, + ); + + if !can_continue || current_run == self.config.depth - 1 { + *last_run_calldata.borrow_mut() = inputs.clone(); + } - if !can_continue || current_run == self.config.depth - 1 { - *last_run_calldata.borrow_mut() = inputs.clone(); - } + if !can_continue { + break + } - if !can_continue { - break + *last_call_results.borrow_mut() = call_results; + current_run += 1; } - *last_call_results.borrow_mut() = call_results; - // Generates the next call from the run using the recently updated // dictionary. - inputs.extend( + inputs.push( strat .new_tree(&mut branch_runner.borrow_mut()) .map_err(|_| TestCaseError::Fail("Could not generate case".into()))? @@ -772,7 +791,7 @@ fn can_continue( failures.reverts += 1; // If fail on revert is set, we must return immediately. if fail_on_revert { - let error = InvariantFuzzError::new( + let case_data = FailedInvariantCaseData::new( invariant_contract, None, calldata, @@ -781,8 +800,8 @@ fn can_continue( shrink_sequence, shrink_run_limit, ); - - failures.revert_reason = Some(error.revert_reason.clone()); + failures.revert_reason = Some(case_data.revert_reason.clone()); + let error = InvariantFuzzError::Revert(case_data); failures.error = Some(error); return RichInvariantResults::new(false, None) diff --git a/crates/evm/fuzz/src/strategies/invariants.rs b/crates/evm/fuzz/src/strategies/invariants.rs index 29d868bad44d..e6dedc9cd8dd 100644 --- a/crates/evm/fuzz/src/strategies/invariants.rs +++ b/crates/evm/fuzz/src/strategies/invariants.rs @@ -59,11 +59,10 @@ pub fn invariant_strat( contracts: FuzzRunIdentifiedContracts, dictionary_weight: u32, calldata_fuzz_config: CalldataFuzzDictionary, -) -> impl Strategy> { +) -> impl Strategy { // We only want to seed the first value, since we want to generate the rest as we mutate the // state generate_call(fuzz_state, senders, contracts, dictionary_weight, calldata_fuzz_config) - .prop_map(|x| vec![x]) } /// Strategy to generate a transaction where the `sender`, `target` and `calldata` are all generated diff --git a/crates/forge/src/runner.rs b/crates/forge/src/runner.rs index ca3e8362940b..29507de8d5d7 100644 --- a/crates/forge/src/runner.rs +++ b/crates/forge/src/runner.rs @@ -25,7 +25,7 @@ use foundry_evm::{ fuzz::{invariant::InvariantContract, CounterExample}, traces::{load_contracts, TraceKind}, }; -use proptest::test_runner::{TestError, TestRunner}; +use proptest::test_runner::TestRunner; use rayon::prelude::*; use std::{ collections::{BTreeMap, HashMap}, @@ -513,26 +513,28 @@ impl<'a> ContractRunner<'a> { let mut logs = logs.clone(); let mut traces = traces.clone(); let success = error.is_none(); - let reason = error - .as_ref() - .and_then(|err| (!err.revert_reason.is_empty()).then(|| err.revert_reason.clone())); + let reason = error.as_ref().and_then(|err| err.revert_reason()); let mut coverage = coverage.clone(); match error { // If invariants were broken, replay the error to collect logs and traces - Some(error @ InvariantFuzzError { test_error: TestError::Fail(_, _), .. }) => { - match error.replay( - self.executor.clone(), - known_contracts, - identified_contracts.clone(), - &mut logs, - &mut traces, - ) { - Ok(c) => counterexample = c, - Err(err) => { - error!(%err, "Failed to replay invariant error"); - } - }; - } + Some(error) => match error { + InvariantFuzzError::BrokenInvariant(case_data) | + InvariantFuzzError::Revert(case_data) => { + match case_data.replay( + self.executor.clone(), + known_contracts, + identified_contracts.clone(), + &mut logs, + &mut traces, + ) { + Ok(c) => counterexample = c, + Err(err) => { + error!(%err, "Failed to replay invariant error"); + } + }; + } + InvariantFuzzError::MaxAssumeRejects(_) => {} + }, // If invariants ran successfully, replay the last run to collect logs and // traces. diff --git a/crates/forge/tests/it/invariant.rs b/crates/forge/tests/it/invariant.rs index 28ac405cc351..99b9ad962783 100644 --- a/crates/forge/tests/it/invariant.rs +++ b/crates/forge/tests/it/invariant.rs @@ -141,6 +141,10 @@ async fn test_invariant() { "fuzz/invariant/common/InvariantCalldataDictionary.t.sol:InvariantCalldataDictionary", vec![("invariant_owner_never_changes()", true, None, None, None)], ), + ( + "fuzz/invariant/common/InvariantAssume.t.sol:InvariantAssume", + vec![("invariant_dummy()", true, None, None, None)], + ), ]), ); } @@ -367,3 +371,40 @@ async fn test_invariant_calldata_fuzz_dictionary_addresses() { )]), ); } + +#[tokio::test(flavor = "multi_thread")] +async fn test_invariant_assume_does_not_revert() { + let filter = Filter::new(".*", ".*", ".*fuzz/invariant/common/InvariantAssume.t.sol"); + let mut runner = runner(); + // Should not treat vm.assume as revert. + runner.test_options.invariant.fail_on_revert = true; + let results = runner.test_collect(&filter); + assert_multiple( + &results, + BTreeMap::from([( + "fuzz/invariant/common/InvariantAssume.t.sol:InvariantAssume", + vec![("invariant_dummy()", true, None, None, None)], + )]), + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_invariant_assume_respects_restrictions() { + let filter = Filter::new(".*", ".*", ".*fuzz/invariant/common/InvariantAssume.t.sol"); + let mut runner = runner(); + runner.test_options.invariant.max_assume_rejects = 1; + let results = runner.test_collect(&filter); + assert_multiple( + &results, + BTreeMap::from([( + "fuzz/invariant/common/InvariantAssume.t.sol:InvariantAssume", + vec![( + "invariant_dummy()", + false, + Some("The `vm.assume` cheatcode rejected too many inputs (1 allowed)".into()), + None, + None, + )], + )]), + ); +} diff --git a/crates/forge/tests/it/test_helpers.rs b/crates/forge/tests/it/test_helpers.rs index 968a0928070e..609181f1a8e3 100644 --- a/crates/forge/tests/it/test_helpers.rs +++ b/crates/forge/tests/it/test_helpers.rs @@ -111,6 +111,7 @@ pub static TEST_OPTS: Lazy = Lazy::new(|| { shrink_sequence: true, shrink_run_limit: 2usize.pow(18u32), preserve_state: false, + max_assume_rejects: 65536, }) .build(&COMPILED, &PROJECT.paths.root) .expect("Config loaded") diff --git a/testdata/fuzz/invariant/common/InvariantAssume.t.sol b/testdata/fuzz/invariant/common/InvariantAssume.t.sol new file mode 100644 index 000000000000..3065a70a550e --- /dev/null +++ b/testdata/fuzz/invariant/common/InvariantAssume.t.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +pragma solidity ^0.8.0; + +import "ds-test/test.sol"; +import "../../../cheats/Vm.sol"; + +contract Handler is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + + function doSomething(uint256 param) public { + vm.assume(param != 0); + } +} + +contract InvariantAssume is DSTest { + Handler handler; + + function setUp() public { + handler = new Handler(); + } + + function invariant_dummy() public {} +}