From 275191f789284882c931aed55214c359972e2759 Mon Sep 17 00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Thu, 25 Jul 2024 15:09:11 -0500 Subject: [PATCH 01/12] fix: weight methods correctly to avoid skipping some --- fuzzing/fuzzer_worker.go | 7 ------ fuzzing/fuzzer_worker_sequence_generator.go | 25 ++++++++++++++++----- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/fuzzing/fuzzer_worker.go b/fuzzing/fuzzer_worker.go index 7ac958b2..57437227 100644 --- a/fuzzing/fuzzer_worker.go +++ b/fuzzing/fuzzer_worker.go @@ -11,7 +11,6 @@ import ( "github.com/crytic/medusa/fuzzing/coverage" "github.com/crytic/medusa/fuzzing/valuegeneration" "github.com/crytic/medusa/utils" - "github.com/crytic/medusa/utils/randomutils" "github.com/ethereum/go-ethereum/common" "golang.org/x/exp/maps" ) @@ -44,9 +43,6 @@ type FuzzerWorker struct { // pureMethods is a list of contract functions which are side-effect free with respect to the EVM (view and/or pure in terms of Solidity mutability). pureMethods []fuzzerTypes.DeployedContractMethod - // methodChooser uses a weighted selection algorithm to choose a method to call, prioritizing state changing methods over pure ones. - methodChooser *randomutils.WeightedRandomChooser[fuzzerTypes.DeployedContractMethod] - // randomProvider provides random data as inputs to decisions throughout the worker. 
randomProvider *rand.Rand // sequenceGenerator creates entirely new or mutated call sequences based on corpus call sequences, for use in @@ -94,7 +90,6 @@ func newFuzzerWorker(fuzzer *Fuzzer, workerIndex int, randomProvider *rand.Rand) coverageTracer: nil, randomProvider: randomProvider, valueSet: valueSet, - methodChooser: randomutils.NewWeightedRandomChooser[fuzzerTypes.DeployedContractMethod](), } worker.sequenceGenerator = NewCallSequenceGenerator(worker, callSequenceGenConfig) worker.shrinkingValueMutator = shrinkingValueMutator @@ -245,10 +240,8 @@ func (fw *FuzzerWorker) updateMethods() { // We favor calling state changing methods over view/pure methods. if method.IsConstant() { fw.pureMethods = append(fw.pureMethods, fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}) - fw.methodChooser.AddChoices(randomutils.NewWeightedRandomChoice(fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}, big.NewInt(1))) } else { fw.stateChangingMethods = append(fw.stateChangingMethods, fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}) - fw.methodChooser.AddChoices(randomutils.NewWeightedRandomChoice(fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}, big.NewInt(100))) } } } diff --git a/fuzzing/fuzzer_worker_sequence_generator.go b/fuzzing/fuzzer_worker_sequence_generator.go index edc3b224..c0267ab3 100644 --- a/fuzzing/fuzzer_worker_sequence_generator.go +++ b/fuzzing/fuzzer_worker_sequence_generator.go @@ -5,6 +5,7 @@ import ( "math/big" "github.com/crytic/medusa/fuzzing/calls" + "github.com/crytic/medusa/fuzzing/contracts" "github.com/crytic/medusa/fuzzing/valuegeneration" "github.com/crytic/medusa/utils" "github.com/crytic/medusa/utils/randomutils" @@ -274,14 +275,26 @@ func (g *CallSequenceGenerator) PopSequenceElement() (*calls.CallSequenceElement // 
deployed to the CallSequenceGenerator's parent FuzzerWorker chain, with fuzzed call data. // Returns the call sequence element, or an error if one was encountered. func (g *CallSequenceGenerator) generateNewElement() (*calls.CallSequenceElement, error) { - // Verify we have state changing methods to call if we are not testing view/pure methods. - if len(g.worker.stateChangingMethods) == 0 && !g.worker.fuzzer.config.Fuzzing.Testing.AssertionTesting.TestViewMethods { - return nil, fmt.Errorf("cannot generate fuzzed tx as there are no state changing methods to call") + // Verify we have state changing methods to call. + onlyPure := false + if len(g.worker.stateChangingMethods) == 0 { + if !g.worker.fuzzer.config.Fuzzing.Testing.AssertionTesting.TestViewMethods { + return nil, fmt.Errorf("cannot generate fuzzed tx as there are no state changing methods to call") + } else if len(g.worker.pureMethods) == 0 { + return nil, fmt.Errorf("cannot generate fuzzed call as there are no methods to call") + } else { + // TestViewMethods && len(g.worker.pureMethods) > 0 + onlyPure = true + } } // Select a random method and sender - selectedMethod, err := g.worker.methodChooser.Choose() - if err != nil { - return nil, err + + // If available, 1 out 100 calls will be pure/view method calls. 
+ var selectedMethod *contracts.DeployedContractMethod + if len(g.worker.pureMethods) > 0 && g.worker.randomProvider.Intn(100) == 0 || onlyPure { + selectedMethod = &g.worker.pureMethods[g.worker.randomProvider.Intn(len(g.worker.pureMethods))] + } else { + selectedMethod = &g.worker.stateChangingMethods[g.worker.randomProvider.Intn(len(g.worker.stateChangingMethods))] } selectedSender := g.worker.fuzzer.senders[g.worker.randomProvider.Intn(len(g.worker.fuzzer.senders))] From 7a1d236be771760231eb2950c3ade8c4b37d44b8 Mon Sep 17 00:00:00 2001 From: Anish Naik Date: Thu, 25 Jul 2024 16:54:37 -0400 Subject: [PATCH 02/12] fix commenting --- fuzzing/fuzzer_worker_sequence_generator.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/fuzzing/fuzzer_worker_sequence_generator.go b/fuzzing/fuzzer_worker_sequence_generator.go index c0267ab3..ae91a267 100644 --- a/fuzzing/fuzzer_worker_sequence_generator.go +++ b/fuzzing/fuzzer_worker_sequence_generator.go @@ -275,21 +275,28 @@ func (g *CallSequenceGenerator) PopSequenceElement() (*calls.CallSequenceElement // deployed to the CallSequenceGenerator's parent FuzzerWorker chain, with fuzzed call data. // Returns the call sequence element, or an error if one was encountered. func (g *CallSequenceGenerator) generateNewElement() (*calls.CallSequenceElement, error) { - // Verify we have state changing methods to call. 
+ // This will track if the fuzzer worker can only invoke pure functions onlyPure := false + + // Verify we have any method (state-changing or pure) to call if len(g.worker.stateChangingMethods) == 0 { if !g.worker.fuzzer.config.Fuzzing.Testing.AssertionTesting.TestViewMethods { + // There are no state-changing methods and we are not testing view/pure methods return nil, fmt.Errorf("cannot generate fuzzed tx as there are no state changing methods to call") } else if len(g.worker.pureMethods) == 0 { + // There are no pure functions to call either return nil, fmt.Errorf("cannot generate fuzzed call as there are no methods to call") - } else { - // TestViewMethods && len(g.worker.pureMethods) > 0 - onlyPure = true } + // Now we know that there are no state-changing functions, there are pure functions, and we can call the + // pure functions + onlyPure = true } - // Select a random method and sender - // If available, 1 out 100 calls will be pure/view method calls. + // Select a random sender + selectedSender := g.worker.fuzzer.senders[g.worker.randomProvider.Intn(len(g.worker.fuzzer.senders))] + + // Select a random method + // There is a 1/100 chance that a pure method will be invoked (or there are onl pure functions) var selectedMethod *contracts.DeployedContractMethod if len(g.worker.pureMethods) > 0 && g.worker.randomProvider.Intn(100) == 0 || onlyPure { selectedMethod = &g.worker.pureMethods[g.worker.randomProvider.Intn(len(g.worker.pureMethods))] @@ -297,8 +304,6 @@ func (g *CallSequenceGenerator) generateNewElement() (*calls.CallSequenceElement selectedMethod = &g.worker.stateChangingMethods[g.worker.randomProvider.Intn(len(g.worker.stateChangingMethods))] } - selectedSender := g.worker.fuzzer.senders[g.worker.randomProvider.Intn(len(g.worker.fuzzer.senders))] - // Generate fuzzed parameters for the function call args := make([]any, len(selectedMethod.Method.Inputs)) for i := 0; i < len(args); i++ { From 3f1b0a976812917c2280b9f7f2c99e8adb20efc3 Mon Sep 17 
00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Thu, 25 Jul 2024 16:38:11 -0500 Subject: [PATCH 03/12] add debugging scripts --- DEV.md | 104 ++++++++++++++++++++ fuzzing/fuzzer_worker_sequence_generator.go | 6 +- scripts/corpus_diff.py | 63 ++++++++++++ scripts/corpus_stats.py | 57 +++++++++++ 4 files changed, 227 insertions(+), 3 deletions(-) create mode 100644 DEV.md create mode 100644 scripts/corpus_diff.py create mode 100644 scripts/corpus_stats.py diff --git a/DEV.md b/DEV.md new file mode 100644 index 00000000..7543fe56 --- /dev/null +++ b/DEV.md @@ -0,0 +1,104 @@ +# Debugging and Development + +## Debugging + +The following scripts are available for Medusa developers for debugging changes to the fuzzer. + +### Corpus diff + +The corpus diff script is used to compare two corpora and identify the methods that are present in one but not the other. This is useful for identifying methods that are missing from a corpus that should be present. + +```shell +python3 scripts/corpus_diff.py corpus1 corpus2 +``` + +```shell +Methods only in ~/corpus1: +- clampSplitWeight(uint32,uint32) + +Methods only in ~/corpus2: + +``` + +### Corpus stats + +The corpus stats script is used to generate statistics about a corpus. This includes the number of sequences, the average length of sequences, and the frequency of methods called. 
+ +```shell +python3 scripts/corpus_stats.py corpus +``` + +```shell +Number of Sequences in ~/corpus: 130 + +Average Length of Transactions List: 43 + +Frequency of Methods Called: +- testReceiversReceivedSplit(uint8): 280 +- setMaxEndHints(uint32,uint32): 174 +- setStreamBalanceWithdrawAll(uint8): 139 +- giveClampedAmount(uint8,uint8,uint128): 136 +- receiveStreamsSplitAndCollectToSelf(uint8): 133 +- testSqueezeViewVsActual(uint8,uint8): 128 +- testSqueeze(uint8,uint8): 128 +- testSetStreamBalance(uint8,int128): 128 +- addStreamWithClamping(uint8,uint8,uint160,uint32,uint32,int128): 125 +- removeAllSplits(uint8): 118 +- testSplittableAfterSplit(uint8): 113 +- testSqueezableVsReceived(uint8): 111 +- testBalanceAtInFuture(uint8,uint8,uint160): 108 +- testRemoveStreamShouldNotRevert(uint8,uint256): 103 +- invariantWithdrawAllTokensShouldNotRevert(): 103 +- collect(uint8,uint8): 101 +- invariantAmtPerSecVsMinAmtPerSec(uint8,uint256): 98 +- testSqueezableAmountCantBeWithdrawn(uint8,uint8): 97 +- split(uint8): 97 +- invariantWithdrawAllTokens(): 95 +- testReceiveStreams(uint8,uint32): 93 +- invariantAccountingVsTokenBalance(): 92 +- testSqueezeWithFuzzedHistoryShouldNotRevert(uint8,uint8,uint256,bytes32): 91 +- testSqueezableAmountCantBeUndone(uint8,uint8,uint160,uint32,uint32,int128): 87 +- testCollect(uint8,uint8): 86 +- testSetStreamBalanceWithdrawAllShouldNotRevert(uint8): 86 +- testAddStreamShouldNotRevert(uint8,uint8,uint160,uint32,uint32,int128): 85 +- testReceiveStreamsShouldNotRevert(uint8): 84 +- addSplitsReceiver(uint8,uint8,uint32): 84 +- setStreamBalanceWithClamping(uint8,int128): 82 +- addSplitsReceiverWithClamping(uint8,uint8,uint32): 80 +- testSetStreamBalanceShouldNotRevert(uint8,int128): 80 +- testSplitShouldNotRevert(uint8): 80 +- squeezeAllAndReceiveAndSplitAndCollectToSelf(uint8): 79 +- addStreamImmediatelySqueezable(uint8,uint8,uint160): 79 +- testSetSplitsShouldNotRevert(uint8,uint8,uint32): 78 +- invariantSumAmtDeltaIsZero(uint8): 78 +- 
testReceiveStreamsViewConsistency(uint8,uint32): 76 +- squeezeToSelf(uint8): 74 +- collectToSelf(uint8): 72 +- setStreams(uint8,uint8,uint160,uint32,uint32,int128): 70 +- receiveStreamsAllCycles(uint8): 69 +- invariantWithdrawShouldAlwaysFail(uint256): 68 +- addStream(uint8,uint8,uint160,uint32,uint32,int128): 68 +- squeezeWithFuzzedHistory(uint8,uint8,uint256,bytes32): 67 +- setStreamsWithClamping(uint8,uint8,uint160,uint32,uint32,int128): 67 +- splitAndCollectToSelf(uint8): 67 +- testSqueezeWithFullyHashedHistory(uint8,uint8): 65 +- give(uint8,uint8,uint128): 65 +- setSplits(uint8,uint8,uint32): 65 +- testSqueezeTwice(uint8,uint8,uint256,bytes32): 65 +- testSetStreamsShouldNotRevert(uint8,uint8,uint160,uint32,uint32,int128): 64 +- squeezeAllSenders(uint8): 63 +- removeStream(uint8,uint256): 62 +- testCollectableAfterSplit(uint8): 58 +- testCollectShouldNotRevert(uint8,uint8): 56 +- testReceiveStreamsViewVsActual(uint8,uint32): 55 +- receiveStreams(uint8,uint32): 55 +- setSplitsWithClamping(uint8,uint8,uint32): 55 +- testGiveShouldNotRevert(uint8,uint8,uint128): 47 +- setStreamBalance(uint8,int128): 47 +- squeezeWithDefaultHistory(uint8,uint8): 45 +- testSplitViewVsActual(uint8): 45 +- testAddSplitsShouldNotRevert(uint8,uint8,uint32): 30 +- testSqueezeWithDefaultHistoryShouldNotRevert(uint8,uint8): 23 + +Number of Unique Methods: 65 +``` diff --git a/fuzzing/fuzzer_worker_sequence_generator.go b/fuzzing/fuzzer_worker_sequence_generator.go index ae91a267..b3f66caf 100644 --- a/fuzzing/fuzzer_worker_sequence_generator.go +++ b/fuzzing/fuzzer_worker_sequence_generator.go @@ -287,8 +287,8 @@ func (g *CallSequenceGenerator) generateNewElement() (*calls.CallSequenceElement // There are no pure functions to call either return nil, fmt.Errorf("cannot generate fuzzed call as there are no methods to call") } - // Now we know that there are no state-changing functions, there are pure functions, and we can call the - // pure functions + // Since there are no state-changing 
functions, there are pure functions, and TestViewMethods is enabled, we can call + // exclusively pure functions. onlyPure = true } @@ -296,7 +296,7 @@ func (g *CallSequenceGenerator) generateNewElement() (*calls.CallSequenceElement selectedSender := g.worker.fuzzer.senders[g.worker.randomProvider.Intn(len(g.worker.fuzzer.senders))] // Select a random method - // There is a 1/100 chance that a pure method will be invoked (or there are onl pure functions) + // There is a 1/100 chance that a pure method will be invoked (or there are only pure functions) var selectedMethod *contracts.DeployedContractMethod if len(g.worker.pureMethods) > 0 && g.worker.randomProvider.Intn(100) == 0 || onlyPure { selectedMethod = &g.worker.pureMethods[g.worker.randomProvider.Intn(len(g.worker.pureMethods))] diff --git a/scripts/corpus_diff.py b/scripts/corpus_diff.py new file mode 100644 index 00000000..b622f212 --- /dev/null +++ b/scripts/corpus_diff.py @@ -0,0 +1,63 @@ +import os +import json +import sys + +def load_json_files_from_subdirectory(subdirectory): + json_data = [] + for root, _, files in os.walk(subdirectory): + for file in files: + if file.endswith('.json'): + with open(os.path.join(root, file), 'r') as f: + data = json.load(f) + json_data.extend(data) + return json_data + +def extract_unique_methods(transactions): + unique_methods = set() + for tx in transactions: + call_data = tx.get('call', {}) + data_abi_values = call_data.get('dataAbiValues', {}) + method_signature = data_abi_values.get('methodSignature', '') + if method_signature: + unique_methods.add(method_signature) + return unique_methods + +def compare_methods(subdirectory1, subdirectory2): + transactions1 = load_json_files_from_subdirectory(subdirectory1) + transactions2 = load_json_files_from_subdirectory(subdirectory2) + + unique_methods1 = extract_unique_methods(transactions1) + unique_methods2 = extract_unique_methods(transactions2) + + only_in_subdir1 = unique_methods1 - unique_methods2 + only_in_subdir2 = 
unique_methods2 - unique_methods1 + + return only_in_subdir1, only_in_subdir2 + +def main(subdirectory1, subdirectory2): + + only_in_subdir1, only_in_subdir2 = compare_methods(subdirectory1, subdirectory2) + + print(f"Methods only in {subdirectory1}:") + if len(only_in_subdir1) == 0: + print(" ") + else: + for method in only_in_subdir1: + print(f"- {method}") + print("\n") + + + print(f"Methods only in {subdirectory2}:") + if len(only_in_subdir2) == 0: + print(" ") + else: + for method in only_in_subdir2: + print(f"- {method}") + print("\n") + +if __name__ == '__main__': + if len(sys.argv) != 3: + print("Usage: python3 corpus_diff.py <corpus1> <corpus2>") + print("Compares the unique methods in the two given corpora.") + sys.exit(1) + main(sys.argv[1], sys.argv[2]) diff --git a/scripts/corpus_stats.py b/scripts/corpus_stats.py new file mode 100644 index 00000000..a5c818f8 --- /dev/null +++ b/scripts/corpus_stats.py @@ -0,0 +1,57 @@ +import os +import json +from collections import Counter +import sys + +def load_json_files_from_subdirectory(subdirectory): + json_data = [] + for root, _, files in os.walk(subdirectory): + for file in files: + if file.endswith('.json'): + with open(os.path.join(root, file), 'r') as f: + data = json.load(f) + json_data.append(data) + return json_data + + +def analyze_transactions(transactions, method_counter): + + for tx in transactions: + call_data = tx.get('call', {}) + data_abi_values = call_data.get('dataAbiValues', {}) + method_signature = data_abi_values.get('methodSignature', '') + + method_counter[method_signature] += 1 + + + +def main(subdirectory): + transaction_seqs = load_json_files_from_subdirectory(subdirectory) + + method_counter = Counter() + total_length = 0 + + for seq in transaction_seqs: + analyze_transactions(seq, method_counter) + total_length += len(seq) + + average_length = total_length // len(transaction_seqs) + + print(f"Number of Sequences in {subdirectory}: {len(transaction_seqs)}") + print("\n") + + print(f"Average Length of 
Transactions List: {average_length}") + print("\n") + print("Frequency of Methods Called:") + for method, count in method_counter.most_common(): + print(f"- {method}: {count}") + print("\n") + print(f"Number of Unique Methods: {len(method_counter)}") + print("\n") + +if __name__ == '__main__': + if len(sys.argv) != 2: + print("Usage: python3 corpus_stats.py <corpus>") + print("Computes statistics on the transactions in the given corpus.") + sys.exit(1) + main(sys.argv[1]) From ca07f6c96354336d81f93444c0b05f06afe4fd13 Mon Sep 17 00:00:00 2001 From: Anish Naik Date: Thu, 25 Jul 2024 18:14:09 -0400 Subject: [PATCH 04/12] zero clue if i optimized anything at all... --- fuzzing/fuzzer_worker.go | 6 ++-- fuzzing/fuzzer_worker_sequence_generator.go | 32 +++++++++------------ 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/fuzzing/fuzzer_worker.go b/fuzzing/fuzzer_worker.go index 57437227..b5f2db80 100644 --- a/fuzzing/fuzzer_worker.go +++ b/fuzzing/fuzzer_worker.go @@ -237,9 +237,11 @@ func (fw *FuzzerWorker) updateMethods() { // If we deployed the contract, also enumerate property tests and state changing methods. for _, method := range contractDefinition.AssertionTestMethods { // Any non-constant method should be tracked as a state changing method. - // We favor calling state changing methods over view/pure methods. 
if method.IsConstant() { - fw.pureMethods = append(fw.pureMethods, fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}) + // Only track the pure/view method if testing view methods is enabled + if fw.fuzzer.config.Fuzzing.Testing.AssertionTesting.TestViewMethods { + fw.pureMethods = append(fw.pureMethods, fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}) + } } else { fw.stateChangingMethods = append(fw.stateChangingMethods, fuzzerTypes.DeployedContractMethod{Address: contractAddress, Contract: contractDefinition, Method: method}) } diff --git a/fuzzing/fuzzer_worker_sequence_generator.go b/fuzzing/fuzzer_worker_sequence_generator.go index ae91a267..666efb23 100644 --- a/fuzzing/fuzzer_worker_sequence_generator.go +++ b/fuzzing/fuzzer_worker_sequence_generator.go @@ -275,35 +275,29 @@ func (g *CallSequenceGenerator) PopSequenceElement() (*calls.CallSequenceElement // deployed to the CallSequenceGenerator's parent FuzzerWorker chain, with fuzzed call data. // Returns the call sequence element, or an error if one was encountered. 
func (g *CallSequenceGenerator) generateNewElement() (*calls.CallSequenceElement, error) { - // This will track if the fuzzer worker can only invoke pure functions - onlyPure := false - - // Verify we have any method (state-changing or pure) to call - if len(g.worker.stateChangingMethods) == 0 { - if !g.worker.fuzzer.config.Fuzzing.Testing.AssertionTesting.TestViewMethods { - // There are no state-changing methods and we are not testing view/pure methods - return nil, fmt.Errorf("cannot generate fuzzed tx as there are no state changing methods to call") - } else if len(g.worker.pureMethods) == 0 { - // There are no pure functions to call either - return nil, fmt.Errorf("cannot generate fuzzed call as there are no methods to call") - } - // Now we know that there are no state-changing functions, there are pure functions, and we can call the - // pure functions - onlyPure = true + // Check to make sure that we have any functions to call + if len(g.worker.stateChangingMethods) == 0 && len(g.worker.pureMethods) == 0 { + return nil, fmt.Errorf("cannot generate fuzzed call as there are no methods to call") } - // Select a random sender - selectedSender := g.worker.fuzzer.senders[g.worker.randomProvider.Intn(len(g.worker.fuzzer.senders))] + // Only call view functions if there are no state-changing methods + var callOnlyPureFunctions bool + if len(g.worker.stateChangingMethods) == 0 && len(g.worker.pureMethods) > 0 { + callOnlyPureFunctions = true + } // Select a random method - // There is a 1/100 chance that a pure method will be invoked (or there are onl pure functions) + // There is a 1/100 chance that a pure method will be invoked or if there are only pure functions that are callable var selectedMethod *contracts.DeployedContractMethod - if len(g.worker.pureMethods) > 0 && g.worker.randomProvider.Intn(100) == 0 || onlyPure { + if (len(g.worker.pureMethods) > 0 && g.worker.randomProvider.Intn(100) == 0) || callOnlyPureFunctions { selectedMethod = 
&g.worker.pureMethods[g.worker.randomProvider.Intn(len(g.worker.pureMethods))] } else { selectedMethod = &g.worker.stateChangingMethods[g.worker.randomProvider.Intn(len(g.worker.stateChangingMethods))] } + // Select a random sender + selectedSender := g.worker.fuzzer.senders[g.worker.randomProvider.Intn(len(g.worker.fuzzer.senders))] + // Generate fuzzed parameters for the function call args := make([]any, len(selectedMethod.Method.Inputs)) for i := 0; i < len(args); i++ { From d1d6344cc3ccb31062bd46017d31c2cd43fbb322 Mon Sep 17 00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Thu, 25 Jul 2024 22:32:52 -0500 Subject: [PATCH 05/12] upload artifact on every PR --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 98641bb0..36aade25 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -87,7 +87,6 @@ jobs: inputs: ./medusa-*.tar.gz - name: Upload artifact - if: github.ref == 'refs/heads/master' || (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')) uses: actions/upload-artifact@v4 with: name: medusa-${{ runner.os }}-${{ runner.arch }} From f54d046ccaae92c3fd68f765aa8830ffe46a17be Mon Sep 17 00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Fri, 26 Jul 2024 15:50:44 -0500 Subject: [PATCH 06/12] fix: log number of workers shrinking (#8) * fix: log number of workers shrinking * report total # failed sequences/ total sequences tested --- fuzzing/fuzzer.go | 7 +++++-- fuzzing/fuzzer_metrics.go | 12 ++++++++++++ fuzzing/test_case_assertion_provider.go | 2 ++ fuzzing/test_case_property_provider.go | 1 + 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/fuzzing/fuzzer.go b/fuzzing/fuzzer.go index 960ebfe2..f1171cc2 100644 --- a/fuzzing/fuzzer.go +++ b/fuzzing/fuzzer.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/ethereum/go-ethereum/crypto" "math/big" "math/rand" "os" @@ -16,6 +15,8 @@ import ( 
"sync" "time" + "github.com/ethereum/go-ethereum/crypto" + "github.com/crytic/medusa/fuzzing/executiontracer" "github.com/crytic/medusa/fuzzing/coverage" @@ -846,6 +847,7 @@ func (f *Fuzzer) printMetricsLoop() { // Obtain our metrics callsTested := f.metrics.CallsTested() sequencesTested := f.metrics.SequencesTested() + failedSequences := f.metrics.FailedSequences() workerStartupCount := f.metrics.WorkerStartupCount() workersShrinking := f.metrics.WorkersShrinkingCount() @@ -865,8 +867,9 @@ func (f *Fuzzer) printMetricsLoop() { logBuffer.Append(", calls: ", colors.Bold, fmt.Sprintf("%d (%d/sec)", callsTested, uint64(float64(new(big.Int).Sub(callsTested, lastCallsTested).Uint64())/secondsSinceLastUpdate)), colors.Reset) logBuffer.Append(", seq/s: ", colors.Bold, fmt.Sprintf("%d", uint64(float64(new(big.Int).Sub(sequencesTested, lastSequencesTested).Uint64())/secondsSinceLastUpdate)), colors.Reset) logBuffer.Append(", coverage: ", colors.Bold, fmt.Sprintf("%d", f.corpus.ActiveMutableSequenceCount()), colors.Reset) + logBuffer.Append(", shrinking: ", colors.Bold, fmt.Sprintf("%v", workersShrinking), colors.Reset) + logBuffer.Append(", failures: ", colors.Bold, fmt.Sprintf("%d/%d", failedSequences, sequencesTested), colors.Reset) if f.logger.Level() <= zerolog.DebugLevel { - logBuffer.Append(", shrinking: ", colors.Bold, fmt.Sprintf("%v", workersShrinking), colors.Reset) logBuffer.Append(", mem: ", colors.Bold, fmt.Sprintf("%v/%v MB", memoryUsedMB, memoryTotalMB), colors.Reset) logBuffer.Append(", resets/s: ", colors.Bold, fmt.Sprintf("%d", uint64(float64(new(big.Int).Sub(workerStartupCount, lastWorkerStartupCount).Uint64())/secondsSinceLastUpdate)), colors.Reset) } diff --git a/fuzzing/fuzzer_metrics.go b/fuzzing/fuzzer_metrics.go index 70fc3788..b0984ab0 100644 --- a/fuzzing/fuzzer_metrics.go +++ b/fuzzing/fuzzer_metrics.go @@ -14,6 +14,9 @@ type fuzzerWorkerMetrics struct { // sequencesTested describes the amount of sequences of transactions which tests were run 
against. sequencesTested *big.Int + //failedSequences describes the amount of sequences of transactions which tests failed. + failedSequences *big.Int + // callsTested describes the amount of transactions/calls the fuzzer executed and ran tests against. callsTested *big.Int @@ -33,12 +36,21 @@ func newFuzzerMetrics(workerCount int) *FuzzerMetrics { } for i := 0; i < len(metrics.workerMetrics); i++ { metrics.workerMetrics[i].sequencesTested = big.NewInt(0) + metrics.workerMetrics[i].failedSequences = big.NewInt(0) metrics.workerMetrics[i].callsTested = big.NewInt(0) metrics.workerMetrics[i].workerStartupCount = big.NewInt(0) } return &metrics } +func (m *FuzzerMetrics) FailedSequences() *big.Int { + failedSequences := big.NewInt(0) + for _, workerMetrics := range m.workerMetrics { + failedSequences.Add(failedSequences, workerMetrics.failedSequences) + } + return failedSequences +} + // SequencesTested returns the amount of sequences of transactions the fuzzer executed and ran tests against. func (m *FuzzerMetrics) SequencesTested() *big.Int { sequencesTested := big.NewInt(0) diff --git a/fuzzing/test_case_assertion_provider.go b/fuzzing/test_case_assertion_provider.go index 8ab4a5bd..f9b9978a 100644 --- a/fuzzing/test_case_assertion_provider.go +++ b/fuzzing/test_case_assertion_provider.go @@ -1,6 +1,7 @@ package fuzzing import ( + "math/big" "sync" "github.com/crytic/medusa/compilation/abiutils" @@ -212,6 +213,7 @@ func (t *AssertionTestCaseProvider) callSequencePostCallTest(worker *FuzzerWorke // Update our test state and report it finalized. 
testCase.status = TestCaseStatusFailed testCase.callSequence = &shrunkenCallSequence + worker.workerMetrics().failedSequences.Add(worker.workerMetrics().failedSequences, big.NewInt(1)) worker.Fuzzer().ReportTestCaseFinished(testCase) return nil }, diff --git a/fuzzing/test_case_property_provider.go b/fuzzing/test_case_property_provider.go index 3681d218..6bb6d419 100644 --- a/fuzzing/test_case_property_provider.go +++ b/fuzzing/test_case_property_provider.go @@ -332,6 +332,7 @@ func (t *PropertyTestCaseProvider) callSequencePostCallTest(worker *FuzzerWorker testCase.status = TestCaseStatusFailed testCase.callSequence = &shrunkenCallSequence testCase.propertyTestTrace = executionTrace + worker.workerMetrics().failedSequences.Add(worker.workerMetrics().failedSequences, big.NewInt(1)) worker.Fuzzer().ReportTestCaseFinished(testCase) return nil }, From 21022fd4749ab8dabc7a45891b6d78661ac62b89 Mon Sep 17 00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Sat, 27 Jul 2024 23:19:32 -0500 Subject: [PATCH 07/12] fix: lookup by source unit id instead of relying on sourceList ordering --- chain/test_chain_test.go | 12 +-- compilation/platforms/crytic_compile.go | 60 +++++++++--- compilation/platforms/crytic_compile_test.go | 43 ++++----- compilation/platforms/solc.go | 47 ++++++---- compilation/types/ast.go | 98 ++++++++++++++++++++ compilation/types/compilation.go | 26 ++---- compilation/types/compiled_contract.go | 3 + compilation/types/compiled_source.go | 7 +- compilation/types/source_maps.go | 11 ++- fuzzing/coverage/source_analysis.go | 30 +++--- 10 files changed, 242 insertions(+), 95 deletions(-) create mode 100644 compilation/types/ast.go diff --git a/chain/test_chain_test.go b/chain/test_chain_test.go index 5d33ceea..ff0ca589 100644 --- a/chain/test_chain_test.go +++ b/chain/test_chain_test.go @@ -215,7 +215,7 @@ func TestChainDynamicDeployments(t *testing.T) { compilations, _, err := cryticCompile.Compile() assert.NoError(t, err) assert.EqualValues(t, 
1, len(compilations)) - assert.EqualValues(t, 1, len(compilations[0].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) // Obtain our chain and senders chain, senders := createChain(t) @@ -223,7 +223,7 @@ func TestChainDynamicDeployments(t *testing.T) { // Deploy each contract that has no construct arguments. deployCount := 0 for _, compilation := range compilations { - for _, source := range compilation.Sources { + for _, source := range compilation.SourcePathToArtifact { for _, contract := range source.Contracts { contract := contract if len(contract.Abi.Constructor.Inputs) == 0 { @@ -329,7 +329,7 @@ func TestChainDeploymentWithArgs(t *testing.T) { compilations, _, err := cryticCompile.Compile() assert.NoError(t, err) assert.EqualValues(t, 1, len(compilations)) - assert.EqualValues(t, 1, len(compilations[0].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) // Obtain our chain and senders chain, senders := createChain(t) @@ -346,7 +346,7 @@ func TestChainDeploymentWithArgs(t *testing.T) { // Deploy each contract deployCount := 0 for _, compilation := range compilations { - for _, source := range compilation.Sources { + for _, source := range compilation.SourcePathToArtifact { for contractName, contract := range source.Contracts { contract := contract @@ -467,7 +467,7 @@ func TestChainCloning(t *testing.T) { // Deploy each contract that has no construct arguments 10 times. for _, compilation := range compilations { - for _, source := range compilation.Sources { + for _, source := range compilation.SourcePathToArtifact { for _, contract := range source.Contracts { contract := contract if len(contract.Abi.Constructor.Inputs) == 0 { @@ -563,7 +563,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) { // Deploy each contract that has no construct arguments 10 times. 
for _, compilation := range compilations { - for _, source := range compilation.Sources { + for _, source := range compilation.SourcePathToArtifact { for _, contract := range source.Contracts { contract := contract if len(contract.Abi.Constructor.Inputs) == 0 { diff --git a/compilation/platforms/crytic_compile.go b/compilation/platforms/crytic_compile.go index 377d1a6d..1b74da1c 100644 --- a/compilation/platforms/crytic_compile.go +++ b/compilation/platforms/crytic_compile.go @@ -8,6 +8,8 @@ import ( "os" "os/exec" "path/filepath" + "regexp" + "strconv" "strings" "github.com/crytic/medusa/compilation/types" @@ -87,6 +89,19 @@ func (c *CryticCompilationConfig) getArgs() ([]string, error) { return args, nil } +func getSourceUnitID(src string) int { + re := regexp.MustCompile(`[0-9]*:[0-9]*:([0-9]*)`) + sourceUnitCandidates := re.FindStringSubmatch(src) + + if len(sourceUnitCandidates) == 2 { // FindStringSubmatch includes the whole match as the first element + sourceUnit, err := strconv.Atoi(sourceUnitCandidates[1]) + if err == nil { + return sourceUnit + } + } + return -1 +} + // Compile uses the CryticCompilationConfig provided to compile a given target, parse the artifacts, and then // create a list of types.Compilation. func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) { @@ -143,7 +158,7 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) var compilationList []types.Compilation // Define the structure of our crytic-compile export data. 
- type solcExportSource struct { + type solcSourceUnit struct { AST any `json:"AST"` } type solcExportContract struct { @@ -154,9 +169,8 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) BinRuntime string `json:"bin-runtime"` } type solcExportData struct { - Sources map[string]solcExportSource `json:"sources"` - Contracts map[string]solcExportContract `json:"contracts"` - SourceList []string `json:"sourceList"` + Sources map[string]solcSourceUnit `json:"sources"` + Contracts map[string]solcExportContract `json:"contracts"` } // Loop through each .json file for compilation units. @@ -176,14 +190,35 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) // Create a compilation object that will store the contracts and source information. compilation := types.NewCompilation() - compilation.SourceList = solcExport.SourceList + + // Create a map of contract names to their kinds + contractKinds := make(map[string]types.ContractKind) // Loop through all sources and parse them into our types. for sourcePath, source := range solcExport.Sources { - compilation.Sources[sourcePath] = types.CompiledSource{ - Ast: source.AST, - Contracts: make(map[string]types.CompiledContract), + + var ast types.AST + b, _ := json.Marshal(source.AST) + err := json.Unmarshal(b, &ast) + if err != nil { + return nil, "", fmt.Errorf("could not parse AST from sources, error: %v", err) + } + // From the AST, extract the contract kinds. 
+ for _, node := range ast.Nodes { + if node.GetNodeType() == "ContractDefinition" { + cdef := node.(types.ContractDefinition) + contractKinds[cdef.CanonicalName] = cdef.ContractKind + } + } + + sourceUnitId := getSourceUnitID(ast.Src) + compilation.SourcePathToArtifact[sourcePath] = types.SourceArtifact{ + // TODO our types.AST is not the same as the original AST but we could parse it and avoid using "any" + Ast: source.AST, + Contracts: make(map[string]types.CompiledContract), + SourceUnitId: sourceUnitId, } + compilation.SourceIdToPath[sourceUnitId] = sourcePath } // Loop through all contracts and parse them into our types. @@ -198,12 +233,12 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) // Ensure a source exists for this, or create one if our path somehow differed from any // path not existing in the "sources" key at the root of the export. - if _, ok := compilation.Sources[sourcePath]; !ok { - parentSource := types.CompiledSource{ + if _, ok := compilation.SourcePathToArtifact[sourcePath]; !ok { + parentSource := types.SourceArtifact{ Ast: nil, Contracts: make(map[string]types.CompiledContract), } - compilation.Sources[sourcePath] = parentSource + compilation.SourcePathToArtifact[sourcePath] = parentSource } // Parse the ABI @@ -223,12 +258,13 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) } // Add contract details - compilation.Sources[sourcePath].Contracts[contractName] = types.CompiledContract{ + compilation.SourcePathToArtifact[sourcePath].Contracts[contractName] = types.CompiledContract{ Abi: *contractAbi, InitBytecode: initBytecode, RuntimeBytecode: runtimeBytecode, SrcMapsInit: contract.SrcMap, SrcMapsRuntime: contract.SrcMapRuntime, + Kind: contractKinds[contractName], } } diff --git a/compilation/platforms/crytic_compile_test.go b/compilation/platforms/crytic_compile_test.go index 3c4ac420..5f0f6e70 100644 --- a/compilation/platforms/crytic_compile_test.go +++ 
b/compilation/platforms/crytic_compile_test.go @@ -1,22 +1,23 @@ package platforms import ( - "github.com/crytic/medusa/compilation/types" - "github.com/crytic/medusa/utils" - "github.com/crytic/medusa/utils/testutils" - "github.com/stretchr/testify/assert" "os" "os/exec" "path/filepath" "strings" "testing" + + "github.com/crytic/medusa/compilation/types" + "github.com/crytic/medusa/utils" + "github.com/crytic/medusa/utils/testutils" + "github.com/stretchr/testify/assert" ) // testCryticGetCompiledSourceByBaseName checks if a given source file exists in a given compilation's map of sources. // The source file is the file name of a specific file. This function simply checks one of the paths ends with // this name. Avoid including any directories in case the path separators differ per system. // Returns the types.CompiledSource (mapping value) associated to the path if it is found. Returns nil otherwise. -func testCryticGetCompiledSourceByBaseName(sources map[string]types.CompiledSource, name string) *types.CompiledSource { +func testCryticGetCompiledSourceByBaseName(sources map[string]types.SourceArtifact, name string) *types.SourceArtifact { // Obtain a lower case version of our name to search for lowerName := strings.ToLower(name) @@ -53,10 +54,10 @@ func TestCryticSingleFileAbsolutePath(t *testing.T) { // One compilation object assert.EqualValues(t, 1, len(compilations)) // One source because we specified one file - assert.EqualValues(t, 1, len(compilations[0].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) // Two contracts in SimpleContract.sol contractCount := 0 - for _, source := range compilations[0].Sources { + for _, source := range compilations[0].SourcePathToArtifact { contractCount += len(source.Contracts) } assert.EqualValues(t, 2, contractCount) @@ -82,10 +83,10 @@ func TestCryticSingleFileRelativePathSameDirectory(t *testing.T) { // One compilation object assert.EqualValues(t, 1, len(compilations)) // One source because 
we specified one file - assert.EqualValues(t, 1, len(compilations[0].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) // Two contracts in SimpleContract.sol contractCount := 0 - for _, source := range compilations[0].Sources { + for _, source := range compilations[0].SourcePathToArtifact { contractCount += len(source.Contracts) } assert.EqualValues(t, 2, contractCount) @@ -118,10 +119,10 @@ func TestCryticSingleFileRelativePathChildDirectory(t *testing.T) { // One compilation object assert.EqualValues(t, 1, len(compilations)) // One source because we specified one file - assert.EqualValues(t, 1, len(compilations[0].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) // Two contracts in SimpleContract.sol contractCount := 0 - for _, source := range compilations[0].Sources { + for _, source := range compilations[0].SourcePathToArtifact { contractCount += len(source.Contracts) } assert.EqualValues(t, 2, contractCount) @@ -160,9 +161,9 @@ func TestCryticSingleFileBuildDirectoryArgRelativePath(t *testing.T) { // One compilation object assert.EqualValues(t, 1, len(compilations)) // One source because we specified one file - assert.EqualValues(t, 1, len(compilations[0].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) // Two contracts in SimpleContract.sol. 
- compiledSource := testCryticGetCompiledSourceByBaseName(compilations[0].Sources, contractName) + compiledSource := testCryticGetCompiledSourceByBaseName(compilations[0].SourcePathToArtifact, contractName) assert.NotNil(t, compiledSource, "source file could not be resolved in compilation sources") assert.EqualValues(t, 2, len(compiledSource.Contracts)) }) @@ -215,11 +216,11 @@ func TestCryticMultipleFiles(t *testing.T) { // Verify there is one compilation object assert.EqualValues(t, 1, len(compilations)) // Verify there are two sources - assert.EqualValues(t, 2, len(compilations[0].Sources)) + assert.EqualValues(t, 2, len(compilations[0].SourcePathToArtifact)) // Verify there are three contracts contractCount := 0 - for _, source := range compilations[0].Sources { + for _, source := range compilations[0].SourcePathToArtifact { contractCount += len(source.Contracts) } assert.EqualValues(t, 3, contractCount) @@ -247,16 +248,16 @@ func TestCryticDirectoryNoArgs(t *testing.T) { // Two compilation objects assert.EqualValues(t, 2, len(compilations)) // One source per compilation unit - assert.EqualValues(t, 1, len(compilations[0].Sources)) - assert.EqualValues(t, 1, len(compilations[1].Sources)) + assert.EqualValues(t, 1, len(compilations[0].SourcePathToArtifact)) + assert.EqualValues(t, 1, len(compilations[1].SourcePathToArtifact)) // Obtain the compiled source from both compilation units firstContractName := "FirstContract.sol" secondContractName := "SecondContract.sol" - firstUnitFirstContractSource := testCryticGetCompiledSourceByBaseName(compilations[0].Sources, firstContractName) - firstUnitSecondContractSource := testCryticGetCompiledSourceByBaseName(compilations[0].Sources, secondContractName) - secondUnitFirstContractSource := testCryticGetCompiledSourceByBaseName(compilations[1].Sources, firstContractName) - secondUnitSecondContractSource := testCryticGetCompiledSourceByBaseName(compilations[1].Sources, secondContractName) + firstUnitFirstContractSource := 
testCryticGetCompiledSourceByBaseName(compilations[0].SourcePathToArtifact, firstContractName) + firstUnitSecondContractSource := testCryticGetCompiledSourceByBaseName(compilations[0].SourcePathToArtifact, secondContractName) + secondUnitFirstContractSource := testCryticGetCompiledSourceByBaseName(compilations[1].SourcePathToArtifact, firstContractName) + secondUnitSecondContractSource := testCryticGetCompiledSourceByBaseName(compilations[1].SourcePathToArtifact, secondContractName) // Assert that each compilation unit should have two contracts in it. // Compilation unit ordering is non-deterministic in JSON output diff --git a/compilation/platforms/solc.go b/compilation/platforms/solc.go index 283d8611..572553f0 100644 --- a/compilation/platforms/solc.go +++ b/compilation/platforms/solc.go @@ -104,23 +104,14 @@ func (s *SolcCompilationConfig) Compile() ([]types.Compilation, string, error) { // Create a compilation unit out of this. compilation := types.NewCompilation() - if sourceList, ok := results["sourceList"]; ok { - if sourceListCasted, ok := sourceList.([]any); ok { - compilation.SourceList = make([]string, len(sourceListCasted)) - for i := 0; i < len(sourceListCasted); i++ { - compilation.SourceList[i] = sourceListCasted[i].(string) - } - } else { - return nil, "", fmt.Errorf("could not parse compiled source artifact because 'sourcesList' was not a []string type") - } - } else { - return nil, "", fmt.Errorf("could not parse compiled source artifact because 'sourcesList' did not exist") - } + + // Create a map of contract names to their kinds + contractKinds := make(map[string]types.ContractKind) // Parse our sources from solc output if sources, ok := results["sources"]; ok { if sourcesMap, ok := sources.(map[string]any); ok { - for name, source := range sourcesMap { + for sourcePath, source := range sourcesMap { // Treat our source as a key-value lookup sourceDict, sourceCorrectType := source.(map[string]any) if !sourceCorrectType { @@ -128,16 +119,35 @@ 
func (s *SolcCompilationConfig) Compile() ([]types.Compilation, string, error) { } // Try to obtain our AST key - ast, hasAST := sourceDict["AST"] + origAST, hasAST := sourceDict["AST"] if !hasAST { return nil, "", fmt.Errorf("could not parse AST from sources, AST field could not be found") } + var ast types.AST + b, _ := json.Marshal(sourceDict["AST"]) + err := json.Unmarshal(b, &ast) + if err != nil { + return nil, "", fmt.Errorf("could not parse AST from sources, error: %v", err) + } + // From the AST, extract the contract kinds. + for _, node := range ast.Nodes { + if node.GetNodeType() == "ContractDefinition" { + cdef := node.(types.ContractDefinition) + contractKinds[cdef.CanonicalName] = cdef.ContractKind + } + } + + sourceUnitId := getSourceUnitID(ast.Src) // Construct our compiled source object - compilation.Sources[name] = types.CompiledSource{ - Ast: ast, - Contracts: make(map[string]types.CompiledContract), + compilation.SourcePathToArtifact[sourcePath] = types.SourceArtifact{ + // TODO our types.AST is not the same as the original AST but we could parse it and avoid using "any" + Ast: origAST, + Contracts: make(map[string]types.CompiledContract), + SourceUnitId: sourceUnitId, } + compilation.SourceIdToPath[sourceUnitId] = sourcePath + } } } @@ -171,12 +181,13 @@ func (s *SolcCompilationConfig) Compile() ([]types.Compilation, string, error) { } // Construct our compiled contract - compilation.Sources[sourcePath].Contracts[contractName] = types.CompiledContract{ + compilation.SourcePathToArtifact[sourcePath].Contracts[contractName] = types.CompiledContract{ Abi: *contractAbi, InitBytecode: initBytecode, RuntimeBytecode: runtimeBytecode, SrcMapsInit: contract.Info.SrcMap.(string), SrcMapsRuntime: contract.Info.SrcMapRuntime, + Kind: contractKinds[contractName], } } diff --git a/compilation/types/ast.go b/compilation/types/ast.go new file mode 100644 index 00000000..3f0d2866 --- /dev/null +++ b/compilation/types/ast.go @@ -0,0 +1,98 @@ +package types + 
+import ( + "encoding/json" + "fmt" +) + +// ContractKind represents the kind of contract +type ContractKind string + +const ( + ContractKindContract ContractKind = "contract" + ContractKindLibrary ContractKind = "library" + ContractKindInterface ContractKind = "interface" +) + +// ContractKindFromString converts a string to a ContractKind +func ContractKindFromString(s string) ContractKind { + switch s { + case "contract": + return ContractKindContract + case "library": + return ContractKindLibrary + case "interface": + return ContractKindInterface + default: + panic(fmt.Sprintf("unknown contract kind: %s", s)) + } +} + +// Node interface represents a generic AST node +type Node interface { + GetNodeType() string +} + +// ContractDefinition is the contract definition node +type ContractDefinition struct { + NodeType string `json:"nodeType"` + CanonicalName string `json:"canonicalName,omitempty"` + ContractKind ContractKind `json:"contractKind,omitempty"` +} + +func (s ContractDefinition) GetNodeType() string { + return s.NodeType +} + +// AST is the abstract syntax tree +type AST struct { + NodeType string `json:"nodeType"` + Nodes []Node `json:"nodes"` + Src string `json:"src"` +} + +// UnmarshalJSON custom unmarshaller for AST +func (a *AST) UnmarshalJSON(data []byte) error { + type Alias AST + aux := &struct { + Nodes []json.RawMessage `json:"nodes"` + *Alias + }{ + Alias: (*Alias)(a), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // Check if nodeType is "SourceUnit" + if aux.NodeType != "SourceUnit" { + return nil + } + + for _, nodeData := range aux.Nodes { + var nodeType struct { + NodeType string `json:"nodeType"` + } + + if err := json.Unmarshal(nodeData, &nodeType); err != nil { + return err + } + + var node Node + switch nodeType.NodeType { + case "ContractDefinition": + var cdef ContractDefinition + if err := json.Unmarshal(nodeData, &cdef); err != nil { + return err + } + node = cdef + // Add cases for other node types 
as needed + default: + continue + } + + a.Nodes = append(a.Nodes, node) + } + + return nil +} diff --git a/compilation/types/compilation.go b/compilation/types/compilation.go index 9958930b..94539f54 100644 --- a/compilation/types/compilation.go +++ b/compilation/types/compilation.go @@ -3,19 +3,16 @@ package types import ( "errors" "fmt" - "golang.org/x/exp/slices" "os" ) // Compilation represents the artifacts of a smart contract compilation. type Compilation struct { - // Sources describes the CompiledSource objects provided in a compilation, housing information regarding source - // files, mappings, ASTs, and contracts. - Sources map[string]CompiledSource + // SourcePathToArtifact maps source file paths to their corresponding SourceArtifact. + SourcePathToArtifact map[string]SourceArtifact - // SourceList describes the CompiledSource keys in Sources, in order. The file identifier used for a SourceMap - // corresponds to an index in this list. - SourceList []string + // SourceIdToPath is a mapping of source unit IDs to source file paths. + SourceIdToPath map[int]string // SourceCode is a lookup of a source file path from SourceList to source code. This is populated by // CacheSourceCode. @@ -26,29 +23,22 @@ type Compilation struct { func NewCompilation() *Compilation { // Create our compilation compilation := &Compilation{ - Sources: make(map[string]CompiledSource), - SourceList: make([]string, 0), - SourceCode: make(map[string][]byte), + SourcePathToArtifact: make(map[string]SourceArtifact), + SourceCode: make(map[string][]byte), + SourceIdToPath: make(map[int]string), } // Return the compilation. return compilation } -// GetSourceFileId obtains the file identifier for a given source file path. This simply checks the index of the -// source file path in SourceList. -// Returns the identifier of the source file, or -1 if it could not be found. 
-func (c *Compilation) GetSourceFileId(sourcePath string) int { - return slices.Index(c.SourceList, sourcePath) -} - // CacheSourceCode caches source code for each CompiledSource in the compilation in the CompiledSource.SourceCode field. // This method will attempt to populate each CompiledSource.SourceCode which has not yet been populated (is nil) before // returning an error, if one occurs. func (c *Compilation) CacheSourceCode() error { // Loop through each source file, try to read it, and collect errors in an aggregated string if we encounter any. var errStr string - for sourcePath := range c.Sources { + for sourcePath := range c.SourcePathToArtifact { if _, ok := c.SourceCode[sourcePath]; !ok { sourceCodeBytes, sourceReadErr := os.ReadFile(sourcePath) if sourceReadErr != nil { diff --git a/compilation/types/compiled_contract.go b/compilation/types/compiled_contract.go index 8cb9f57a..33693ec8 100644 --- a/compilation/types/compiled_contract.go +++ b/compilation/types/compiled_contract.go @@ -29,6 +29,9 @@ type CompiledContract struct { // SrcMapsRuntime describes the source mappings to associate source file and bytecode segments in RuntimeBytecode. SrcMapsRuntime string + + // Kind describes the kind of contract, i.e. contract, library, interface. + Kind ContractKind } // IsMatch returns a boolean indicating whether provided contract bytecode is a match to this compiled contract diff --git a/compilation/types/compiled_source.go b/compilation/types/compiled_source.go index 9adbd283..2950a74b 100644 --- a/compilation/types/compiled_source.go +++ b/compilation/types/compiled_source.go @@ -1,8 +1,8 @@ package types -// CompiledSource represents a source descriptor for a smart contract compilation, including AST and contained +// SourceArtifact represents a source descriptor for a smart contract compilation, including AST and contained // CompiledContract instances. 
-type CompiledSource struct { +type SourceArtifact struct { // Ast describes the abstract syntax tree artifact of a source file compilation, providing tokenization of the // source file components. Ast any @@ -10,4 +10,7 @@ type CompiledSource struct { // Contracts describes a mapping of contract names to contract definition structures which are contained within // the source. Contracts map[string]CompiledContract + + // SourceUnitId refers to the identifier of the source unit within the compilation. + SourceUnitId int } diff --git a/compilation/types/source_maps.go b/compilation/types/source_maps.go index c5a92ad0..57da3fb2 100644 --- a/compilation/types/source_maps.go +++ b/compilation/types/source_maps.go @@ -2,9 +2,10 @@ package types import ( "fmt" - "github.com/ethereum/go-ethereum/core/vm" "strconv" "strings" + + "github.com/ethereum/go-ethereum/core/vm" ) // Reference: Source mapping is performed according to the rules specified in solidity documentation: @@ -47,8 +48,8 @@ type SourceMapElement struct { // Length refers to the byte length of the source range the instruction maps to. Length int - // FileID refers to an identifier for the CompiledSource file which houses the relevant source code. - FileID int + // SourceUnitID refers to an identifier for the CompiledSource file which houses the relevant source code. + SourceUnitID int // JumpType refers to the SourceMapJumpType which provides information about any type of jump that occurred. JumpType SourceMapJumpType @@ -83,7 +84,7 @@ func ParseSourceMap(sourceMapStr string) (SourceMap, error) { Index: -1, Offset: -1, Length: -1, - FileID: -1, + SourceUnitID: -1, JumpType: "", ModifierDepth: 0, } @@ -120,7 +121,7 @@ func ParseSourceMap(sourceMapStr string) (SourceMap, error) { // If the source file identifier exists, update our current element data. 
if len(fields) > 2 && fields[2] != "" { - current.FileID, err = strconv.Atoi(fields[2]) + current.SourceUnitID, err = strconv.Atoi(fields[2]) if err != nil { return nil, err } diff --git a/fuzzing/coverage/source_analysis.go b/fuzzing/coverage/source_analysis.go index baed6505..49b8aafa 100644 --- a/fuzzing/coverage/source_analysis.go +++ b/fuzzing/coverage/source_analysis.go @@ -3,9 +3,10 @@ package coverage import ( "bytes" "fmt" + "sort" + "github.com/crytic/medusa/compilation/types" "golang.org/x/exp/maps" - "sort" ) // SourceAnalysis describes source code coverage across a list of compilations, after analyzing associated CoverageMaps. @@ -117,7 +118,7 @@ func AnalyzeSourceCoverage(compilations []types.Compilation, coverageMaps *Cover // Loop through all sources in all compilations to add them to our source file analysis container. for _, compilation := range compilations { - for sourcePath := range compilation.Sources { + for sourcePath := range compilation.SourcePathToArtifact { // If we have no source code loaded for this source, skip it. if _, ok := compilation.SourceCode[sourcePath]; !ok { return nil, fmt.Errorf("could not perform source code analysis, code was not cached for '%v'", sourcePath) @@ -135,9 +136,13 @@ func AnalyzeSourceCoverage(compilations []types.Compilation, coverageMaps *Cover // Loop through all sources in all compilations to process coverage information. for _, compilation := range compilations { - for _, source := range compilation.Sources { + for _, source := range compilation.SourcePathToArtifact { // Loop for each contract in this source for _, contract := range source.Contracts { + // Skip interfaces. + if contract.Kind == types.ContractKindInterface { + continue + } // Obtain coverage map data for this contract. 
initCoverageMapData, err := coverageMaps.GetContractCoverageMap(contract.InitBytecode, true) if err != nil { @@ -196,20 +201,19 @@ func analyzeContractSourceCoverage(compilation types.Compilation, sourceAnalysis for _, sourceMapElement := range sourceMap { // If this source map element doesn't map to any file (compiler generated inline code), it will have no // relevance to the coverage map, so we skip it. - if sourceMapElement.FileID == -1 { + if sourceMapElement.SourceUnitID == -1 { continue } - // Verify this file ID is not out of bounds for a source file index - if sourceMapElement.FileID < 0 || sourceMapElement.FileID >= len(compilation.SourceList) { - // TODO: We may also go out of bounds because this maps to a "generated source" which we do not have. - // For now, we silently skip these cases. + // Obtain our source for this file ID + sourcePath, idExists := compilation.SourceIdToPath[sourceMapElement.SourceUnitID] + + // TODO: We may also go out of bounds because this maps to a "generated source" which we do not have. + // For now, we silently skip these cases. + if !idExists { continue } - // Obtain our source for this file ID - sourcePath := compilation.SourceList[sourceMapElement.FileID] - // Check if the source map element was executed. sourceMapElementCovered := false sourceMapElementCoveredReverted := false @@ -258,7 +262,7 @@ func filterSourceMaps(compilation types.Compilation, sourceMap types.SourceMap) // Loop for each source map entry and determine if it should be included. for i, sourceMapElement := range sourceMap { // Verify this file ID is not out of bounds for a source file index - if sourceMapElement.FileID < 0 || sourceMapElement.FileID >= len(compilation.SourceList) { + if _, exists := compilation.SourceIdToPath[sourceMapElement.SourceUnitID]; !exists { // TODO: We may also go out of bounds because this maps to a "generated source" which we do not have. // For now, we silently skip these cases. 
continue @@ -267,7 +271,7 @@ func filterSourceMaps(compilation types.Compilation, sourceMap types.SourceMap) // Verify this source map does not overlap another encapsulatesOtherMapping := false for x, sourceMapElement2 := range sourceMap { - if i != x && sourceMapElement.FileID == sourceMapElement2.FileID && + if i != x && sourceMapElement.SourceUnitID == sourceMapElement2.SourceUnitID && !(sourceMapElement.Offset == sourceMapElement2.Offset && sourceMapElement.Length == sourceMapElement2.Length) { if sourceMapElement2.Offset >= sourceMapElement.Offset && sourceMapElement2.Offset+sourceMapElement2.Length <= sourceMapElement.Offset+sourceMapElement.Length { From 3e315a14c337001350c7f2f26fcc8e353cdf3b37 Mon Sep 17 00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Sat, 27 Jul 2024 23:20:15 -0500 Subject: [PATCH 08/12] skip library and interface in target contract lookup --- fuzzing/fuzzer.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/fuzzing/fuzzer.go b/fuzzing/fuzzer.go index f1171cc2..c28ac9d8 100644 --- a/fuzzing/fuzzer.go +++ b/fuzzing/fuzzer.go @@ -291,13 +291,19 @@ func (f *Fuzzer) AddCompilationTargets(compilations []compilationTypes.Compilati compilation := &f.compilations[len(f.compilations)-1] // Loop for each source - for sourcePath, source := range compilation.Sources { + for sourcePath, source := range compilation.SourcePathToArtifact { // Seed our base value set from every source's AST f.baseValueSet.SeedFromAst(source.Ast) // Loop for every contract and register it in our contract definitions for contractName := range source.Contracts { contract := source.Contracts[contractName] + + // Skip interfaces. 
+ if contract.Kind == compilationTypes.ContractKindInterface { + continue + } + contractDefinition := fuzzerTypes.NewContract(contractName, sourcePath, &contract, compilation) // Sort available methods by type @@ -395,11 +401,17 @@ func chainSetupFromCompilations(fuzzer *Fuzzer, testChain *chain.TestChain) (*ex // Verify that target contracts is not empty. If it's empty, but we only have one contract definition, // we can infer the target contracts. Otherwise, we report an error. if len(fuzzer.config.Fuzzing.TargetContracts) == 0 { - // TODO filter libraries - if len(fuzzer.contractDefinitions) == 1 { - fuzzer.config.Fuzzing.TargetContracts = []string{fuzzer.contractDefinitions[0].Name()} - } else { - return nil, fmt.Errorf("missing target contracts") + var found bool + for _, contract := range fuzzer.contractDefinitions { + // If only one contract is defined, we can infer the target contract by filtering interfaces/libraries. + if contract.CompiledContract().Kind == compilationTypes.ContractKindContract { + if !found { + fuzzer.config.Fuzzing.TargetContracts = []string{contract.Name()} + found = true + } else { + return nil, fmt.Errorf("specify target contract(s)") + } + } } } @@ -816,7 +828,11 @@ func (f *Fuzzer) Start() error { if err == nil && f.config.Fuzzing.CorpusDirectory != "" { coverageReportPath := filepath.Join(f.config.Fuzzing.CorpusDirectory, "coverage_report.html") err = coverage.GenerateReport(f.compilations, f.corpus.CoverageMaps(), coverageReportPath) - f.logger.Info("Coverage report saved to file: ", colors.Bold, coverageReportPath, colors.Reset) + if err != nil { + f.logger.Error("Failed to generate coverage report", err) + } else { + f.logger.Info("Coverage report saved to file: ", colors.Bold, coverageReportPath, colors.Reset) + } } // Return any encountered error. 
From 2dcc0f9fe9eac2eb4ada34a812e9ab067605d45d Mon Sep 17 00:00:00 2001 From: Anish Naik Date: Wed, 31 Jul 2024 14:36:29 -0400 Subject: [PATCH 09/12] improve commenting and error handling --- compilation/platforms/crytic_compile.go | 14 ++++++++++---- compilation/types/ast.go | 24 ++++++------------------ 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/compilation/platforms/crytic_compile.go b/compilation/platforms/crytic_compile.go index 1b74da1c..d43a6d12 100644 --- a/compilation/platforms/crytic_compile.go +++ b/compilation/platforms/crytic_compile.go @@ -196,13 +196,19 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) // Loop through all sources and parse them into our types. for sourcePath, source := range solcExport.Sources { - + // Convert the AST into our version of the AST (types.AST) var ast types.AST - b, _ := json.Marshal(source.AST) - err := json.Unmarshal(b, &ast) + b, err = json.Marshal(source.AST) + if err != nil { + + return nil, "", fmt.Errorf("could not encode AST from sources: %v", err) + } + err = json.Unmarshal(b, &ast) if err != nil { - return nil, "", fmt.Errorf("could not parse AST from sources, error: %v", err) + + return nil, "", fmt.Errorf("could not parse AST from sources: %v", err) } + // From the AST, extract the contract kinds. 
for _, node := range ast.Nodes { if node.GetNodeType() == "ContractDefinition" { diff --git a/compilation/types/ast.go b/compilation/types/ast.go index 3f0d2866..fac97f51 100644 --- a/compilation/types/ast.go +++ b/compilation/types/ast.go @@ -2,32 +2,20 @@ package types import ( "encoding/json" - "fmt" ) -// ContractKind represents the kind of contract +// ContractKind represents the kind of contract represented by an AST node type ContractKind string const ( - ContractKindContract ContractKind = "contract" - ContractKindLibrary ContractKind = "library" + // ContractKindContract represents a contract node + ContractKindContract ContractKind = "contract" + // ContractKindLibrary represents a library node + ContractKindLibrary ContractKind = "library" + // ContractKindInterface represents an interface node ContractKindInterface ContractKind = "interface" ) -// ContractKindFromString converts a string to a ContractKind -func ContractKindFromString(s string) ContractKind { - switch s { - case "contract": - return ContractKindContract - case "library": - return ContractKindLibrary - case "interface": - return ContractKindInterface - default: - panic(fmt.Sprintf("unknown contract kind: %s", s)) - } -} - // Node interface represents a generic AST node type Node interface { GetNodeType() string From 2328a683c98d3e0b35d186d4629a64eadf398fb8 Mon Sep 17 00:00:00 2001 From: Anish Naik Date: Wed, 31 Jul 2024 15:24:45 -0400 Subject: [PATCH 10/12] added more comments --- compilation/platforms/crytic_compile.go | 9 +++---- compilation/types/ast.go | 31 ++++++++++++++++--------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/compilation/platforms/crytic_compile.go b/compilation/platforms/crytic_compile.go index d43a6d12..4ee08d1a 100644 --- a/compilation/platforms/crytic_compile.go +++ b/compilation/platforms/crytic_compile.go @@ -209,17 +209,18 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) return nil, "", fmt.Errorf("could 
not parse AST from sources: %v", err) } - // From the AST, extract the contract kinds. + // From the AST, extract the contract kinds where the contract definition could be for a contract, library, + // or interface for _, node := range ast.Nodes { if node.GetNodeType() == "ContractDefinition" { - cdef := node.(types.ContractDefinition) - contractKinds[cdef.CanonicalName] = cdef.ContractKind + contractDefinition := node.(types.ContractDefinition) + contractKinds[contractDefinition.CanonicalName] = contractDefinition.Kind } } sourceUnitId := getSourceUnitID(ast.Src) compilation.SourcePathToArtifact[sourcePath] = types.SourceArtifact{ - // TODO our types.AST is not the same as the original AST but we could parse it and avoid using "any" + // TODO: Our types.AST is not the same as the original AST but we could parse it and avoid using "any" Ast: source.AST, Contracts: make(map[string]types.CompiledContract), SourceUnitId: sourceUnitId, diff --git a/compilation/types/ast.go b/compilation/types/ast.go index fac97f51..6b495312 100644 --- a/compilation/types/ast.go +++ b/compilation/types/ast.go @@ -4,7 +4,7 @@ import ( "encoding/json" ) -// ContractKind represents the kind of contract represented by an AST node +// ContractKind represents the kind of contract definition represented by an AST node type ContractKind string const ( @@ -23,11 +23,15 @@ type Node interface { // ContractDefinition is the contract definition node type ContractDefinition struct { - NodeType string `json:"nodeType"` - CanonicalName string `json:"canonicalName,omitempty"` - ContractKind ContractKind `json:"contractKind,omitempty"` + // NodeType represents the AST node type (note that it will always be a contract definition) + NodeType string `json:"nodeType"` + // CanonicalName is the name of the contract definition + CanonicalName string `json:"canonicalName,omitempty"` + // Kind is a ContractKind that represents what type of contract definition this is (contract, interface, or library) + Kind 
ContractKind `json:"contractKind,omitempty"` } +// GetNodeType implements the Node interface and returns the node type for the contract definition func (s ContractDefinition) GetNodeType() string { return s.NodeType } @@ -39,8 +43,9 @@ type AST struct { Src string `json:"src"` } -// UnmarshalJSON custom unmarshaller for AST +// UnmarshalJSON unmarshals from JSON func (a *AST) UnmarshalJSON(data []byte) error { + // Unmarshal the top-level AST into our own representation. Defer the unmarshaling of all the individual nodes until later type Alias AST aux := &struct { Nodes []json.RawMessage `json:"nodes"` @@ -52,33 +57,37 @@ func (a *AST) UnmarshalJSON(data []byte) error { return err } - // Check if nodeType is "SourceUnit" + // Check if nodeType is "SourceUnit". Return early otherwise if aux.NodeType != "SourceUnit" { return nil } + // Iterate through all the nodes of the source unit for _, nodeData := range aux.Nodes { + // Unmarshal the node data to retrieve the node type var nodeType struct { NodeType string `json:"nodeType"` } - if err := json.Unmarshal(nodeData, &nodeType); err != nil { return err } + // Unmarshal the contents of the node based on the node type var node Node switch nodeType.NodeType { case "ContractDefinition": - var cdef ContractDefinition - if err := json.Unmarshal(nodeData, &cdef); err != nil { + // If this is a contract definition, unmarshal it + var contractDefinition ContractDefinition + if err := json.Unmarshal(nodeData, &contractDefinition); err != nil { return err } - node = cdef - // Add cases for other node types as needed + node = contractDefinition + // TODO: Add cases for other node types as needed default: continue } + // Append the node a.Nodes = append(a.Nodes, node) } From 4eb55b896723973c30080f7bc934d5c33076ea06 Mon Sep 17 00:00:00 2001 From: Anish Naik Date: Wed, 31 Jul 2024 15:42:03 -0400 Subject: [PATCH 11/12] migrate getSourceUnitID function to AST class --- compilation/platforms/crytic_compile.go | 20 ++------------------ 
compilation/platforms/solc.go | 19 +++++++++++++------ compilation/types/ast.go | 23 +++++++++++++++++++++-- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/compilation/platforms/crytic_compile.go b/compilation/platforms/crytic_compile.go index 4ee08d1a..a9e07fae 100644 --- a/compilation/platforms/crytic_compile.go +++ b/compilation/platforms/crytic_compile.go @@ -8,8 +8,6 @@ import ( "os" "os/exec" "path/filepath" - "regexp" - "strconv" "strings" "github.com/crytic/medusa/compilation/types" @@ -89,19 +87,6 @@ func (c *CryticCompilationConfig) getArgs() ([]string, error) { return args, nil } -func getSourceUnitID(src string) int { - re := regexp.MustCompile(`[0-9]*:[0-9]*:([0-9]*)`) - sourceUnitCandidates := re.FindStringSubmatch(src) - - if len(sourceUnitCandidates) == 2 { // FindStringSubmatch includes the whole match as the first element - sourceUnit, err := strconv.Atoi(sourceUnitCandidates[1]) - if err == nil { - return sourceUnit - } - } - return -1 -} - // Compile uses the CryticCompilationConfig provided to compile a given target, parse the artifacts, and then // create a list of types.Compilation. 
func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) { @@ -200,12 +185,10 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) var ast types.AST b, err = json.Marshal(source.AST) if err != nil { - return nil, "", fmt.Errorf("could not encode AST from sources: %v", err) } err = json.Unmarshal(b, &ast) if err != nil { - return nil, "", fmt.Errorf("could not parse AST from sources: %v", err) } @@ -218,7 +201,8 @@ func (c *CryticCompilationConfig) Compile() ([]types.Compilation, string, error) } } - sourceUnitId := getSourceUnitID(ast.Src) + // Retrieve the source unit ID + sourceUnitId := ast.GetSourceUnitID() compilation.SourcePathToArtifact[sourcePath] = types.SourceArtifact{ // TODO: Our types.AST is not the same as the original AST but we could parse it and avoid using "any" Ast: source.AST, diff --git a/compilation/platforms/solc.go b/compilation/platforms/solc.go index 572553f0..4ceef747 100644 --- a/compilation/platforms/solc.go +++ b/compilation/platforms/solc.go @@ -124,21 +124,28 @@ func (s *SolcCompilationConfig) Compile() ([]types.Compilation, string, error) { return nil, "", fmt.Errorf("could not parse AST from sources, AST field could not be found") } + // Convert the AST into our version of the AST (types.AST) var ast types.AST - b, _ := json.Marshal(sourceDict["AST"]) - err := json.Unmarshal(b, &ast) + b, err := json.Marshal(origAST) + if err != nil { + return nil, "", fmt.Errorf("could not encode AST from sources: %v", err) + } + err = json.Unmarshal(b, &ast) if err != nil { return nil, "", fmt.Errorf("could not parse AST from sources, error: %v", err) } - // From the AST, extract the contract kinds. 
+ + // From the AST, extract the contract kinds where the contract definition could be for a contract, library, + // or interface for _, node := range ast.Nodes { if node.GetNodeType() == "ContractDefinition" { - cdef := node.(types.ContractDefinition) - contractKinds[cdef.CanonicalName] = cdef.ContractKind + contractDefinition := node.(types.ContractDefinition) + contractKinds[contractDefinition.CanonicalName] = contractDefinition.Kind } } - sourceUnitId := getSourceUnitID(ast.Src) + // Get the source unit ID + sourceUnitId := ast.GetSourceUnitID() // Construct our compiled source object compilation.SourcePathToArtifact[sourcePath] = types.SourceArtifact{ // TODO our types.AST is not the same as the original AST but we could parse it and avoid using "any" diff --git a/compilation/types/ast.go b/compilation/types/ast.go index 6b495312..f6b21612 100644 --- a/compilation/types/ast.go +++ b/compilation/types/ast.go @@ -2,6 +2,8 @@ package types import ( "encoding/json" + "regexp" + "strconv" ) // ContractKind represents the kind of contract definition represented by an AST node @@ -38,9 +40,12 @@ func (s ContractDefinition) GetNodeType() string { // AST is the abstract syntax tree type AST struct { + // NodeType represents the node type (currently we only evaluate source unit node types) NodeType string `json:"nodeType"` - Nodes []Node `json:"nodes"` - Src string `json:"src"` + // Nodes is a list of Nodes within the AST + Nodes []Node `json:"nodes"` + // Src is the source file for this AST + Src string `json:"src"` } // UnmarshalJSON unmarshals from JSON @@ -93,3 +98,17 @@ func (a *AST) UnmarshalJSON(data []byte) error { return nil } + +// GetSourceUnitID returns the source unit ID based on the source of the AST +func (a *AST) GetSourceUnitID() int { + re := regexp.MustCompile(`[0-9]*:[0-9]*:([0-9]*)`) + sourceUnitCandidates := re.FindStringSubmatch(a.Src) + + if len(sourceUnitCandidates) == 2 { // FindStringSubmatch includes the whole match as the first element + 
sourceUnit, err := strconv.Atoi(sourceUnitCandidates[1]) + if err == nil { + return sourceUnit + } + } + return -1 +} From f4d710f254581f9898bf2659b7980364163eb914 Mon Sep 17 00:00:00 2001 From: Anish Naik Date: Fri, 2 Aug 2024 12:31:20 -0400 Subject: [PATCH 12/12] fix comments --- fuzzing/fuzzer_metrics.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fuzzing/fuzzer_metrics.go b/fuzzing/fuzzer_metrics.go index b0984ab0..ef495a80 100644 --- a/fuzzing/fuzzer_metrics.go +++ b/fuzzing/fuzzer_metrics.go @@ -14,7 +14,7 @@ type fuzzerWorkerMetrics struct { // sequencesTested describes the amount of sequences of transactions which tests were run against. sequencesTested *big.Int - //failedSequences describes the amount of sequences of transactions which tests failed. + // failedSequences describes the amount of sequences of transactions which tests failed. failedSequences *big.Int // callsTested describes the amount of transactions/calls the fuzzer executed and ran tests against. @@ -43,6 +43,7 @@ func newFuzzerMetrics(workerCount int) *FuzzerMetrics { return &metrics } +// FailedSequences returns the number of sequences that led to failures across all workers func (m *FuzzerMetrics) FailedSequences() *big.Int { failedSequences := big.NewInt(0) for _, workerMetrics := range m.workerMetrics {