diff --git a/fuzzing/calls/call_sequence_execution.go b/fuzzing/calls/call_sequence_execution.go index ca983f0d..3416536c 100644 --- a/fuzzing/calls/call_sequence_execution.go +++ b/fuzzing/calls/call_sequence_execution.go @@ -3,10 +3,13 @@ package calls import ( "fmt" + "math/big" + "github.com/crytic/medusa/chain" "github.com/crytic/medusa/fuzzing/contracts" "github.com/crytic/medusa/fuzzing/executiontracer" "github.com/crytic/medusa/utils" + "github.com/ethereum/go-ethereum/core" ) // ExecuteCallSequenceFetchElementFunc describes a function that is called to obtain the next call sequence element to @@ -51,107 +54,163 @@ func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc Exe break } - // We try to add the transaction with our call more than once. If the pending block is too full, we may hit a - // block gas limit, which we handle by committing the pending block without this tx, and creating a new pending - // block that is empty to try adding this tx there instead. - // If we encounter an error on an empty block, we throw the error as there is nothing more we can do. - for { - // If we have a pending block, but we intend to delay this call from the last, we commit that block. - if chain.PendingBlock() != nil && callSequenceElement.BlockNumberDelay > 0 { - err := chain.PendingBlockCommit() - if err != nil { - return callSequenceExecuted, err - } - } + // Process the call sequence element + err = processCallSequenceElement(chain, callSequenceElement, &callSequenceExecuted, additionalTracers...) + if err != nil { + return callSequenceExecuted, err + } - // If we have no pending block to add a tx containing our call to, we must create one. - if chain.PendingBlock() == nil { - // The minimum step between blocks must be 1 in block number and timestamp, so we ensure this is the - // case. - numberDelay := callSequenceElement.BlockNumberDelay - timeDelay := callSequenceElement.BlockTimestampDelay - if numberDelay == 0 { - numberDelay = 1 - } - if timeDelay == 0 { - timeDelay = 1 - } + // We added our call to the block as a transaction. Call our step function with the update and check + // if it returned an error. + if executionCheckFunc != nil { + execCheckFuncRequestedBreak, err = executionCheckFunc(callSequenceExecuted) + if err != nil { + return callSequenceExecuted, err + } - // Each timestamp/block number must be unique as well, so we cannot jump more block numbers than time. - if numberDelay > timeDelay { - numberDelay = timeDelay - } - _, err := chain.PendingBlockCreateWithParameters(chain.Head().Header.Number.Uint64()+numberDelay, chain.Head().Header.Time+timeDelay, nil) - if err != nil { - return callSequenceExecuted, err - } + // If post-execution check requested we break execution, break out of our "retry loop" + if execCheckFuncRequestedBreak { + break } + } + } - // Try to add our transaction to this block. - err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), additionalTracers...) + return callSequenceExecuted, nil +} - if err != nil { - // If we encountered a block gas limit error, this tx is too expensive to fit in this block. - // If there are other transactions in the block, this makes sense. The block is "full". - // In that case, we commit the pending block without this tx, and create a new pending block to add - // our tx to, and iterate to try and add it again. - // TODO: This should also check the condition that this is a block gas error specifically. For now, we - // simply assume it is and try processing in an empty block (if that fails, that error will be - // returned). - if len(chain.PendingBlock().Messages) > 0 { - err := chain.PendingBlockCommit() - if err != nil { - return callSequenceExecuted, err - } - continue - } +// processCallSequenceElement handles the execution of a single call sequence element, including the setup hook. +func processCallSequenceElement(chain *chain.TestChain, callSequenceElement *CallSequenceElement, callSequenceExecuted *CallSequence, additionalTracers ...*chain.TestChainTracer) error { + // We try to add the transaction with our call more than once. If the pending block is too full, we may hit a + // block gas limit, which we handle by committing the pending block without this tx, and creating a new pending + // block that is empty to try adding this tx there instead. + // If we encounter an error on an empty block, we throw the error as there is nothing more we can do. - // If there are no transactions in our block, and we failed to add this one, return the error - return callSequenceExecuted, err + // Process contract setup hook if present + if callSequenceElement.Contract.SetupHook != nil { + err := executeContractSetupHook(chain, callSequenceElement, callSequenceExecuted) + if err != nil { + return err + } + } + + // Process the main call sequence element + return executeCall(chain, callSequenceElement, callSequenceExecuted, additionalTracers...) +} + +// executeContractSetupHook processes the contract setup hook for the call sequence element. +func executeContractSetupHook(chain *chain.TestChain, callSequenceElement *CallSequenceElement, callSequenceExecuted *CallSequence) error { + // Get our contract setup hook + contractSetupHook := callSequenceElement.Contract.SetupHook + + // Create a call targeting our setup hook + msg := NewCallMessageWithAbiValueData(contractSetupHook.DeployerAddress, callSequenceElement.Call.To, 0, big.NewInt(0), callSequenceElement.Call.GasLimit, nil, nil, nil, &CallMessageDataAbiValues{ + Method: contractSetupHook.Method, + InputValues: nil, + }) + msg.FillFromTestChainProperties(chain) + + // Execute the call + // If we have no pending block to add a tx containing our call to, we must create one. + err := addTxToPendingBlock(chain, callSequenceElement.BlockNumberDelay, callSequenceElement.BlockTimestampDelay, msg.ToCoreMessage()) + if err != nil { + return err + } + + setupCallSequenceElement := NewCallSequenceElement(callSequenceElement.Contract, msg, callSequenceElement.BlockNumberDelay, callSequenceElement.BlockTimestampDelay) + setupCallSequenceElement.ChainReference = &CallSequenceElementChainReference{ + Block: chain.PendingBlock(), + TransactionIndex: len(chain.PendingBlock().Messages) - 1, + } + + // Register the call in our call sequence so it gets registered in coverage. + *callSequenceExecuted = append(*callSequenceExecuted, setupCallSequenceElement) + return nil +} + +// executeCall processes the main call of the call sequence element. +func executeCall(chain *chain.TestChain, callSequenceElement *CallSequenceElement, callSequenceExecuted *CallSequence, additionalTracers ...*chain.TestChainTracer) error { + // Update call sequence element call message if setup hook was executed + if callSequenceElement.Contract.SetupHook != nil { + callSequenceElement.Call.FillFromTestChainProperties(chain) + } + + // Try to add our transaction to this block. + err := addTxToPendingBlock(chain, callSequenceElement.BlockNumberDelay, callSequenceElement.BlockTimestampDelay, callSequenceElement.Call.ToCoreMessage(), additionalTracers...) + if err != nil { + return err + } + + // Update our chain reference for this element. + callSequenceElement.ChainReference = &CallSequenceElementChainReference{ + Block: chain.PendingBlock(), + TransactionIndex: len(chain.PendingBlock().Messages) - 1, + } + + // Add to our executed call sequence + *callSequenceExecuted = append(*callSequenceExecuted, callSequenceElement) + return nil +} + +// addTxToPendingBlock attempts to add a transaction to the pending block, handling block creation and retries as necessary. +func addTxToPendingBlock(chain *chain.TestChain, numberDelay, timeDelay uint64, txMessage *core.Message, additionalTracers ...*chain.TestChainTracer) error { + for { + // If we have a pending block, but we intend to delay this call from the last, we commit that block. + if chain.PendingBlock() != nil && numberDelay > 0 { + err := chain.PendingBlockCommit() + if err != nil { + return err } + } - // Update our chain reference for this element. - callSequenceElement.ChainReference = &CallSequenceElementChainReference{ - Block: chain.PendingBlock(), - TransactionIndex: len(chain.PendingBlock().Messages) - 1, + // If we have no pending block to add a tx containing our call to, we must create one. + if chain.PendingBlock() == nil { + // The minimum step between blocks must be 1 in block number and timestamp, so we ensure this is the + // case. + if numberDelay == 0 { + numberDelay = 1 + } + if timeDelay == 0 { + timeDelay = 1 } - // Add to our executed call sequence - callSequenceExecuted = append(callSequenceExecuted, callSequenceElement) + // Each timestamp/block number must be unique as well, so we cannot jump more block numbers than time. + if numberDelay > timeDelay { + numberDelay = timeDelay + } + _, err := chain.PendingBlockCreateWithParameters(chain.Head().Header.Number.Uint64()+numberDelay, chain.Head().Header.Time+timeDelay, nil) + if err != nil { + return err + } + } - // We added our call to the block as a transaction. Call our step function with the update and check - // if it returned an error. - if executionCheckFunc != nil { - execCheckFuncRequestedBreak, err = executionCheckFunc(callSequenceExecuted) + // Try to add our transaction to this block. + err := chain.PendingBlockAddTx(txMessage, additionalTracers...) + if err != nil { + // If we encountered a block gas limit error, this tx is too expensive to fit in this block. + // If there are other transactions in the block, this makes sense. The block is "full". + // In that case, we commit the pending block without this tx, and create a new pending block to add + // our tx to, and iterate to try and add it again. + // TODO: This should also check the condition that this is a block gas error specifically. For now, we + // simply assume it is and try processing in an empty block (if that fails, that error will be + // returned). + if len(chain.PendingBlock().Messages) > 0 { + err := chain.PendingBlockCommit() if err != nil { - return callSequenceExecuted, err - } - - // If post-execution check requested we break execution, break out of our "retry loop" - if execCheckFuncRequestedBreak { - break + return err } + continue } - // We didn't encounter an error, so we were successful in adding this transaction. Break out of this - // inner "retry loop" and move onto processing the next element in the outer loop. - break + // If there are no transactions in our block, and we failed to add this one, return the error + return err } - // If post-execution check requested we break execution, break out of our "execute next call sequence loop" - if execCheckFuncRequestedBreak { - break - } + // We didn't encounter an error, so we were successful in adding this transaction. Break out of this + // inner "retry loop" and move onto processing the next element in the outer loop. + break } - // Commit the last pending block. - if chain.PendingBlock() != nil { - err := chain.PendingBlockCommit() - if err != nil { - return callSequenceExecuted, err - } - } - return callSequenceExecuted, nil + return nil } // ExecuteCallSequence executes a provided CallSequence on the provided chain. diff --git a/fuzzing/contracts/contract.go b/fuzzing/contracts/contract.go index 30ad094a..c0895d4e 100644 --- a/fuzzing/contracts/contract.go +++ b/fuzzing/contracts/contract.go @@ -1,16 +1,27 @@ package contracts import ( - "golang.org/x/exp/slices" "strings" + "golang.org/x/exp/slices" + "github.com/crytic/medusa/compilation/types" "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" ) // Contracts describes an array of contracts type Contracts []*Contract +// ContractSetupHook describes a contract setup hook +type ContractSetupHook struct { + // Method represents the setup function + Method *abi.Method + + // DeployerAddress represents the fuzzer's deployer address, to be used when calling the setup hook. + DeployerAddress common.Address +} + // MatchBytecode takes init and/or runtime bytecode and attempts to match it to a contract definition in the // current list of contracts. It returns the contract definition if found. Otherwise, it returns nil. func (c Contracts) MatchBytecode(initBytecode []byte, runtimeBytecode []byte) *Contract { @@ -40,6 +51,9 @@ type Contract struct { // compilation describes the compilation which contains the compiledContract. compilation *types.Compilation + // setupHook describes the contract's setup hook, if it exists. + SetupHook *ContractSetupHook + // PropertyTestMethods are the methods that are property tests. PropertyTestMethods []abi.Method diff --git a/fuzzing/fuzzer.go b/fuzzing/fuzzer.go index 4fcc00d5..cd64f7ca 100644 --- a/fuzzing/fuzzer.go +++ b/fuzzing/fuzzer.go @@ -307,7 +307,7 @@ func (f *Fuzzer) AddCompilationTargets(compilations []compilationTypes.Compilati contractDefinition := fuzzerTypes.NewContract(contractName, sourcePath, &contract, compilation) // Sort available methods by type - assertionTestMethods, propertyTestMethods, optimizationTestMethods := fuzzingutils.BinTestByType(&contract, + assertionTestMethods, propertyTestMethods, optimizationTestMethods, setupHook := fuzzingutils.BinTestByType(&contract, f.config.Fuzzing.Testing.PropertyTesting.TestPrefixes, f.config.Fuzzing.Testing.OptimizationTesting.TestPrefixes, f.config.Fuzzing.Testing.AssertionTesting.TestViewMethods) @@ -315,6 +315,11 @@ func (f *Fuzzer) AddCompilationTargets(compilations []compilationTypes.Compilati contractDefinition.PropertyTestMethods = propertyTestMethods contractDefinition.OptimizationTestMethods = optimizationTestMethods + // Register the contract's setup hook, if exists + if setupHook != nil { + contractDefinition.SetupHook = &fuzzerTypes.ContractSetupHook{Method: setupHook, DeployerAddress: f.deployer} + } + // Filter and record methods available for assertion testing. Property and optimization tests are always run. if len(f.config.Fuzzing.Testing.TargetFunctionSignatures) > 0 { // Only consider methods that are in the target methods list diff --git a/fuzzing/fuzzer_test.go b/fuzzing/fuzzer_test.go index 5f56ca4c..68c5f07f 100644 --- a/fuzzing/fuzzer_test.go +++ b/fuzzing/fuzzer_test.go @@ -182,6 +182,26 @@ func TestOptimizationMode(t *testing.T) { } } +// TestSetupHook runs a test to ensure that setup hooks work as expected. +func TestSetupHook(t *testing.T) { + runFuzzerTest(t, &fuzzerSolcFileTest{ + filePath: "testdata/contracts/assertions/assert_setup_hook.sol", + configUpdates: func(config *config.ProjectConfig) { + config.Fuzzing.TargetContracts = []string{"TestContract", "TestContract2"} + config.Fuzzing.TestLimit = 10_000 + config.Fuzzing.Testing.AssertionTesting.Enabled = true + }, + method: func(f *fuzzerTestContext) { + // Start the fuzzer + err := f.fuzzer.Start() + assert.NoError(t, err) + + // Assert that we should not have failures. + assertFailedTestsExpected(f, false) + }, + }) +} + // TestChainBehaviour runs tests to ensure the chain behaves as expected. func TestChainBehaviour(t *testing.T) { // Run a test to simulate out of gas errors to make sure its handled well by the Chain and does not panic. diff --git a/fuzzing/testdata/contracts/assertions/assert_setup_hook.sol b/fuzzing/testdata/contracts/assertions/assert_setup_hook.sol new file mode 100644 index 00000000..b097186f --- /dev/null +++ b/fuzzing/testdata/contracts/assertions/assert_setup_hook.sol @@ -0,0 +1,46 @@ +// These contracts ensure that setUp hooks work as expected. +contract TestContract { + bool public state = false; + + function setUp() public { + state = true; + } + + function one() public { + assert(state); + + state = false; + } + + function two() public { + assert(state); + + state = false; + } + + function three() public { + assert(state); + + state = false; + } +} + +contract TestContract2 { + uint256 public num = 0; + + function setUp() public { + num = 3; + } + + function four() public { + assert(num == 3); + } + + function five() public { + assert(num == 3); + } + + function six() public { + assert(num == 3); + } +} \ No newline at end of file diff --git a/fuzzing/utils/fuzz_method_utils.go b/fuzzing/utils/fuzz_method_utils.go index 70b77a12..6783777b 100644 --- a/fuzzing/utils/fuzz_method_utils.go +++ b/fuzzing/utils/fuzz_method_utils.go @@ -10,6 +10,10 @@ import ( // IsOptimizationTest checks whether the method is an optimization test given potential naming prefixes it must conform to // and its underlying input/output arguments. func IsOptimizationTest(method abi.Method, prefixes []string) bool { + // Don't register setup hook as a test case + if IsSetupHook(method) { + return false + } // Loop through all enabled prefixes to find a match for _, prefix := range prefixes { if strings.HasPrefix(method.Name, prefix) { @@ -25,6 +29,10 @@ func IsOptimizationTest(method abi.Method, prefixes []string) bool { // IsPropertyTest checks whether the method is a property test given potential naming prefixes it must conform to // and its underlying input/output arguments. func IsPropertyTest(method abi.Method, prefixes []string) bool { + // Don't register setup hook as a test case + if IsSetupHook(method) { + return false + } // Loop through all enabled prefixes to find a match for _, prefix := range prefixes { // The property test must simply have the right prefix and take no inputs and return a boolean @@ -37,10 +45,17 @@ func IsPropertyTest(method abi.Method, prefixes []string) bool { return false } -// BinTestByType sorts a contract's methods by whether they are assertion, property, or optimization tests. -func BinTestByType(contract *compilationTypes.CompiledContract, propertyTestPrefixes, optimizationTestPrefixes []string, testViewMethods bool) (assertionTests, propertyTests, optimizationTests []abi.Method) { +// IsSetupHook checks whether the method is a setup hook +func IsSetupHook(method abi.Method) bool { + return method.Name == "setUp" +} + +// BinTestByType sorts a contract's methods by whether they are the contract's setup hook,assertion, property, or optimization tests. +func BinTestByType(contract *compilationTypes.CompiledContract, propertyTestPrefixes, optimizationTestPrefixes []string, testViewMethods bool) (assertionTests, propertyTests, optimizationTests []abi.Method, setupHook *abi.Method) { for _, method := range contract.Abi.Methods { - if IsPropertyTest(method, propertyTestPrefixes) { + if IsSetupHook(method) { + setupHook = &method + } else if IsPropertyTest(method, propertyTestPrefixes) { propertyTests = append(propertyTests, method) } else if IsOptimizationTest(method, optimizationTestPrefixes) { optimizationTests = append(optimizationTests, method) @@ -48,5 +63,5 @@ func BinTestByType(contract *compilationTypes.CompiledContract, propertyTestPref assertionTests = append(assertionTests, method) } } - return assertionTests, propertyTests, optimizationTests + return assertionTests, propertyTests, optimizationTests, setupHook }