Skip to content

Commit 6d5fb76

Browse files
committed
re-execute call sequences, lookup exec. trace by tx hash
1 parent bfe0ba8 commit 6d5fb76

12 files changed

+179
-76
lines changed

chain/test_chain.go

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"github.com/crytic/medusa/chain/config"
1010
"github.com/ethereum/go-ethereum/core/rawdb"
1111
"github.com/ethereum/go-ethereum/core/tracing"
12+
"github.com/ethereum/go-ethereum/crypto"
13+
"github.com/ethereum/go-ethereum/eth/tracers"
1214
"github.com/ethereum/go-ethereum/triedb"
1315
"github.com/ethereum/go-ethereum/triedb/hashdb"
1416
"github.com/holiman/uint256"
@@ -84,7 +86,7 @@ type TestChain struct {
8486
// NewTestChain creates a simulated Ethereum backend used for testing, or returns an error if one occurred.
8587
// This creates a test chain with a test chain configuration and the provided genesis allocation and config.
8688
// If a nil config is provided, a default one is used.
87-
func NewTestChain(genesisAlloc core.GenesisAlloc, testChainConfig *config.TestChainConfig) (*TestChain, error) {
89+
func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestChainConfig) (*TestChain, error) {
8890
// Copy our chain config, so it is not shared across chains.
8991
chainConfig, err := utils.CopyChainConfig(params.TestChainConfig)
9092
if err != nil {
@@ -143,7 +145,7 @@ func NewTestChain(genesisAlloc core.GenesisAlloc, testChainConfig *config.TestCh
143145
return nil, err
144146
}
145147
for _, cheatContract := range cheatContracts {
146-
genesisDefinition.Alloc[cheatContract.address] = core.GenesisAccount{
148+
genesisDefinition.Alloc[cheatContract.address] = types.Account{
147149
Balance: big.NewInt(0),
148150
Code: []byte{0xFF},
149151
}
@@ -251,7 +253,7 @@ func (t *TestChain) Clone(onCreateFunc func(chain *TestChain) error) (*TestChain
251253
// Now add each transaction/message to it.
252254
messages := t.blocks[i].Messages
253255
for j := 0; j < len(messages); j++ {
254-
err = targetChain.PendingBlockAddTx(messages[j])
256+
err = targetChain.PendingBlockAddTx(messages[j], nil)
255257
if err != nil {
256258
return nil, err
257259
}
@@ -561,7 +563,7 @@ func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additi
561563
}
562564

563565
// Obtain our state snapshot to revert any changes after our call
564-
// snapshot := state.Snapshot()
566+
snapshot := state.Snapshot()
565567

566568
// Set infinite balance to the fake caller account
567569
state.AddBalance(msg.From, uint256.MustFromBig(math.MaxBig256), tracing.BalanceChangeUnspecified)
@@ -585,19 +587,57 @@ func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additi
585587
})
586588
t.evm = evm
587589

590+
tx := utils.MessageToTransaction(msg)
588591
if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil {
589-
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), utils.MessageToTransaction(msg), msg.From)
592+
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, msg.From)
590593
}
591594
// Fund the gas pool, so it can execute endlessly (no block gas limit).
592595
gasPool := new(core.GasPool).AddGas(math.MaxUint64)
593596

594597
// Perform our state transition to obtain the result.
595-
res, err := core.NewStateTransition(evm, msg, gasPool).TransitionDb()
598+
msgResult, err := core.ApplyMessage(evm, msg, gasPool)
596599

597600
// Revert to our state snapshot to undo any changes.
598-
// state.RevertToSnapshot(snapshot)
601+
if err != nil {
602+
state.RevertToSnapshot(snapshot)
603+
}
604+
605+
// Receipt:
606+
var root []byte
607+
if t.chainConfig.IsByzantium(blockContext.BlockNumber) {
608+
t.state.Finalise(true)
609+
} else {
610+
root = state.IntermediateRoot(t.chainConfig.IsEIP158(blockContext.BlockNumber)).Bytes()
611+
}
599612

600-
return res, err
613+
// Create a new receipt for the transaction, storing the intermediate root and
614+
// gas used by the tx.
615+
receipt := &types.Receipt{Type: tx.Type(), PostState: root, CumulativeGasUsed: msgResult.UsedGas}
616+
if msgResult.Failed() {
617+
receipt.Status = types.ReceiptStatusFailed
618+
} else {
619+
receipt.Status = types.ReceiptStatusSuccessful
620+
}
621+
receipt.TxHash = tx.Hash()
622+
receipt.GasUsed = msgResult.UsedGas
623+
624+
// If the transaction created a contract, store the creation address in the receipt.
625+
if msg.To == nil {
626+
receipt.ContractAddress = crypto.CreateAddress(evm.TxContext.Origin, tx.Nonce())
627+
}
628+
629+
// Set the receipt logs and create the bloom filter.
630+
receipt.Logs = t.state.GetLogs(tx.Hash(), blockContext.BlockNumber.Uint64(), blockContext.GetHash(blockContext.BlockNumber.Uint64()))
631+
receipt.Bloom = types.CreateBloom(types.Receipts{receipt})
632+
receipt.TransactionIndex = uint(0)
633+
634+
if evm.Config.Tracer != nil {
635+
if evm.Config.Tracer.OnTxEnd != nil {
636+
evm.Config.Tracer.OnTxEnd(receipt, nil)
637+
}
638+
}
639+
640+
return msgResult, err
601641
}
602642

603643
// PendingBlock describes the current pending block which is being constructed and awaiting commitment to the chain.
@@ -701,16 +741,19 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi
701741

702742
// PendingBlockAddTx takes a message (internal txs) and adds it to the current pending block, updating the header
703743
// with relevant execution information. If a pending block was not created, an error is returned.
704-
// Returns the constructed block, or an error if one occurred.
705-
func (t *TestChain) PendingBlockAddTx(message *core.Message) error {
744+
// Returns an error if one occurred.
745+
func (t *TestChain) PendingBlockAddTx(message *core.Message, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) error {
746+
if getTracerFn == nil {
747+
getTracerFn = func(txIndex int, txHash common.Hash) *tracers.Tracer {
748+
return t.transactionTracerRouter.NativeTracer.Tracer
749+
}
750+
}
751+
706752
// If we don't have a pending block, return an error
707753
if t.pendingBlock == nil {
708754
return errors.New("could not add tx to the chain's pending block because no pending block was created")
709755
}
710756

711-
// Obtain our state root hash prior to execution.
712-
// previousStateRoot := t.pendingBlock.Header.Root
713-
714757
// Create a gas pool indicating how much gas can be spent executing the transaction.
715758
gasPool := new(core.GasPool).AddGas(t.pendingBlock.Header.GasLimit - t.pendingBlock.Header.GasUsed)
716759

@@ -721,17 +764,25 @@ func (t *TestChain) PendingBlockAddTx(message *core.Message) error {
721764
// TODO reuse
722765
blockContext := newTestChainBlockContext(t, t.pendingBlock.Header)
723766

724-
// Create our EVM instance.
725-
evm := vm.NewEVM(blockContext, core.NewEVMTxContext(message), t.state, t.chainConfig, vm.Config{
767+
vmConfig := vm.Config{
726768
//Debug: true,
727-
Tracer: t.transactionTracerRouter.NativeTracer.Hooks,
728769
NoBaseFee: true,
729770
ConfigExtensions: t.vmConfigExtensions,
730-
})
731-
t.evm = evm
771+
}
772+
773+
tracer := getTracerFn(len(t.pendingBlock.Messages), tx.Hash())
774+
if tracer != nil {
775+
vmConfig.Tracer = tracer.Hooks
776+
}
732777

733778
t.state.SetTxContext(tx.Hash(), len(t.pendingBlock.Messages))
734779

780+
// Create our EVM instance.
781+
evm := vm.NewEVM(blockContext, core.NewEVMTxContext(message), t.state, t.chainConfig, vmConfig)
782+
783+
// Set our EVM instance for the test chain in order for cheatcodes to access EVM interpreter's block context.
784+
t.evm = evm
785+
735786
if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil {
736787
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, message.From)
737788
}
@@ -934,11 +985,11 @@ func (t *TestChain) emitContractChangeEvents(reverting bool, messageResults ...*
934985
Contract: deploymentChange.Contract,
935986
})
936987
} else if deploymentChange.Destroyed {
937-
err = t.Events.ContractDeploymentAddedEventEmitter.Publish(ContractDeploymentsAddedEvent{
938-
Chain: t,
939-
Contract: deploymentChange.Contract,
940-
DynamicDeployment: deploymentChange.DynamicCreation,
941-
})
988+
// err = t.Events.ContractDeploymentAddedEventEmitter.Publish(ContractDeploymentsAddedEvent{
989+
// Chain: t,
990+
// Contract: deploymentChange.Contract,
991+
// DynamicDeployment: deploymentChange.DynamicCreation,
992+
// })
942993
}
943994
if err != nil {
944995
return err

chain/test_chain_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ func createChain(t *testing.T) (*TestChain, []common.Address) {
6464
assert.NoError(t, err)
6565

6666
// NOTE: Sharing GenesisAlloc between nodes will result in some accounts not being funded for some reason.
67-
genesisAlloc := make(core.GenesisAlloc)
67+
genesisAlloc := make(types.GenesisAlloc)
6868

6969
// Fund all of our sender addresses in the genesis block
7070
initBalance := new(big.Int).Div(abi.MaxInt256, big.NewInt(2))
7171
for _, sender := range senders {
72-
genesisAlloc[sender] = core.GenesisAccount{
72+
genesisAlloc[sender] = types.Account{
7373
Balance: initBalance,
7474
}
7575
}
@@ -260,7 +260,7 @@ func TestChainDynamicDeployments(t *testing.T) {
260260
assert.NoError(t, err)
261261

262262
// Add our transaction to the block
263-
err = chain.PendingBlockAddTx(&msg)
263+
err = chain.PendingBlockAddTx(&msg, nil)
264264
assert.NoError(t, err)
265265

266266
// Commit the pending block to the chain, so it becomes the new head.
@@ -385,7 +385,7 @@ func TestChainDeploymentWithArgs(t *testing.T) {
385385
assert.NoError(t, err)
386386

387387
// Add our transaction to the block
388-
err = chain.PendingBlockAddTx(&msg)
388+
err = chain.PendingBlockAddTx(&msg, nil)
389389
assert.NoError(t, err)
390390

391391
// Commit the pending block to the chain, so it becomes the new head.
@@ -494,7 +494,7 @@ func TestChainCloning(t *testing.T) {
494494
assert.NoError(t, err)
495495

496496
// Add our transaction to the block
497-
err = chain.PendingBlockAddTx(&msg)
497+
err = chain.PendingBlockAddTx(&msg, nil)
498498
assert.NoError(t, err)
499499

500500
// Commit the pending block to the chain, so it becomes the new head.
@@ -588,7 +588,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
588588
assert.NoError(t, err)
589589

590590
// Add our transaction to the block
591-
err = chain.PendingBlockAddTx(&msg)
591+
err = chain.PendingBlockAddTx(&msg, nil)
592592
assert.NoError(t, err)
593593

594594
// Commit the pending block to the chain, so it becomes the new head.
@@ -627,7 +627,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) {
627627
_, err := recreatedChain.PendingBlockCreate()
628628
assert.NoError(t, err)
629629
for _, message := range chain.blocks[i].Messages {
630-
err = recreatedChain.PendingBlockAddTx(message)
630+
err = recreatedChain.PendingBlockAddTx(message, nil)
631631
assert.NoError(t, err)
632632
}
633633
err = recreatedChain.PendingBlockCommit()

compilation/types/compiled_contract.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ func (c *CompiledContract) IsMatch(initBytecode []byte, runtimeBytecode []byte)
4848
deploymentBytecodeHash := deploymentMetadata.ExtractBytecodeHash()
4949
definitionBytecodeHash := definitionMetadata.ExtractBytecodeHash()
5050
if deploymentBytecodeHash != nil && definitionBytecodeHash != nil {
51-
x := bytes.Equal(deploymentBytecodeHash, definitionBytecodeHash)
52-
return x
51+
return bytes.Equal(deploymentBytecodeHash, definitionBytecodeHash)
5352
}
5453
}
5554
}

fuzzing/calls/call_sequence_execution.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ package calls
22

33
import (
44
"fmt"
5+
56
"github.com/crytic/medusa/chain"
7+
"github.com/ethereum/go-ethereum/common"
8+
"github.com/ethereum/go-ethereum/eth/tracers"
69
)
710

811
// ExecuteCallSequenceFetchElementFunc describes a function that is called to obtain the next call sequence element to
@@ -22,7 +25,7 @@ type ExecuteCallSequenceExecutionCheckFunc func(currentExecutedSequence CallSequ
2225
// A "post element executed check" function is provided to check whether execution should stop after each element is
2326
// executed.
2427
// Returns the call sequence which was executed and an error if one occurs.
25-
func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc) (CallSequence, error) {
28+
func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc ExecuteCallSequenceFetchElementFunc, executionCheckFunc ExecuteCallSequenceExecutionCheckFunc, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) (CallSequence, error) {
2629
// If there is no fetch element function provided, throw an error
2730
if fetchElementFunc == nil {
2831
return nil, fmt.Errorf("could not execute call sequence on chain as the 'fetch element function' provided was nil")
@@ -84,7 +87,8 @@ func ExecuteCallSequenceIteratively(chain *chain.TestChain, fetchElementFunc Exe
8487
}
8588

8689
// Try to add our transaction to this block.
87-
err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage())
90+
err = chain.PendingBlockAddTx(callSequenceElement.Call.ToCoreMessage(), getTracerFn)
91+
8892
if err != nil {
8993
// If we encountered a block gas limit error, this tx is too expensive to fit in this block.
9094
// If there are other transactions in the block, this makes sense. The block is "full".
@@ -161,6 +165,18 @@ func ExecuteCallSequence(chain *chain.TestChain, callSequence CallSequence) (Cal
161165
return nil, nil
162166
}
163167

168+
return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil, nil)
169+
}
170+
171+
func ExecuteCallSequenceWithTracer(chain *chain.TestChain, callSequence CallSequence, getTracerFn func(txIndex int, txHash common.Hash) *tracers.Tracer) (CallSequence, error) {
172+
// Execute our sequence with a simple fetch operation provided to obtain each element.
173+
fetchElementFunc := func(currentIndex int) (*CallSequenceElement, error) {
174+
if currentIndex < len(callSequence) {
175+
return callSequence[currentIndex], nil
176+
}
177+
return nil, nil
178+
}
179+
164180
// Execute our provided call sequence iteratively.
165-
return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil)
181+
return ExecuteCallSequenceIteratively(chain, fetchElementFunc, nil, getTracerFn)
166182
}

fuzzing/corpus/corpus.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ func (c *Corpus) initializeSequences(sequenceFiles *corpusDirectory[calls.CallSe
210210
}
211211

212212
// Execute each call sequence, populating runtime data and collecting coverage data along the way.
213-
_, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc)
213+
_, err = calls.ExecuteCallSequenceIteratively(testChain, fetchElementFunc, executionCheckFunc, nil)
214214

215215
// If we failed to replay a sequence and measure coverage due to an unexpected error, report it.
216216
if err != nil {
@@ -228,8 +228,7 @@ func (c *Corpus) initializeSequences(sequenceFiles *corpusDirectory[calls.CallSe
228228
}
229229

230230
// Revert chain state to our starting point to test the next sequence.
231-
err = testChain.RevertToBlockNumber(baseBlockNumber)
232-
if err != nil {
231+
if err := testChain.RevertToBlockNumber(baseBlockNumber); err != nil {
233232
return fmt.Errorf("failed to reset the chain while seeding coverage: %v\n", err)
234233
}
235234
}

fuzzing/executiontracer/execution_tracer.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/crytic/medusa/chain"
88
"github.com/crytic/medusa/fuzzing/contracts"
9+
"github.com/crytic/medusa/utils"
910
"github.com/ethereum/go-ethereum/common"
1011
"github.com/ethereum/go-ethereum/core"
1112
"github.com/ethereum/go-ethereum/core/state"
@@ -23,15 +24,16 @@ import (
2324
func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state *state.StateDB) (*core.ExecutionResult, *ExecutionTrace, error) {
2425
// Create an execution tracer
2526
executionTracer := NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
26-
27+
defer executionTracer.Close()
2728
// Call the contract on our chain with the provided state.
2829
executionResult, err := testChain.CallContract(msg, state, executionTracer.NativeTracer)
2930
if err != nil {
3031
return nil, nil, err
3132
}
3233

3334
// Obtain our trace
34-
trace := executionTracer.Trace()
35+
hash := utils.MessageToTransaction(msg).Hash()
36+
trace := executionTracer.GetTrace(hash)
3537

3638
// Return the trace
3739
return executionResult, trace, nil
@@ -49,6 +51,8 @@ type ExecutionTracer struct {
4951
// trace represents the current execution trace captured by this tracer.
5052
trace *ExecutionTrace
5153

54+
traceMap map[common.Hash]*ExecutionTrace
55+
5256
// currentCallFrame references the current call frame being traced.
5357
currentCallFrame *CallFrame
5458

@@ -72,11 +76,13 @@ func NewExecutionTracer(contractDefinitions contracts.Contracts, cheatCodeContra
7276
tracer := &ExecutionTracer{
7377
contractDefinitions: contractDefinitions,
7478
cheatCodeContracts: cheatCodeContracts,
79+
traceMap: make(map[common.Hash]*ExecutionTrace),
7580
}
7681
nativeTracer := &tracers.Tracer{
7782
Hooks: &tracing.Hooks{
7883
OnTxStart: tracer.OnTxStart,
7984
OnEnter: tracer.OnEnter,
85+
OnTxEnd: tracer.OnTxEnd,
8086
OnExit: tracer.OnExit,
8187
OnOpcode: tracer.OnOpcode,
8288
},
@@ -85,10 +91,19 @@ func NewExecutionTracer(contractDefinitions contracts.Contracts, cheatCodeContra
8591

8692
return tracer
8793
}
94+
func (t *ExecutionTracer) Close() {
95+
t.traceMap = nil
96+
}
8897

89-
// Trace returns the currently recording or last recorded execution trace by the tracer.
90-
func (t *ExecutionTracer) Trace() *ExecutionTrace {
91-
return t.trace
98+
// GetTrace returns the currently recording or last recorded execution trace by the tracer.
99+
func (t *ExecutionTracer) GetTrace(txHash common.Hash) *ExecutionTrace {
100+
if trace, ok := t.traceMap[txHash]; ok {
101+
return trace
102+
}
103+
return nil
104+
}
105+
func (t *ExecutionTracer) OnTxEnd(receipt *coretypes.Receipt, err error) {
106+
t.traceMap[receipt.TxHash] = t.trace
92107
}
93108

94109
// CaptureTxStart is called upon the start of transaction execution, as defined by tracers.Tracer.

0 commit comments

Comments
 (0)