Skip to content

Commit ed1ea36

Browse files
Adds support for label cheatcode (crytic#545)
Adds support for the vm.label cheatcode. The label cheatcode will append the label to the TestChain. These labels can then be used by the execution tracer. --------- Co-authored-by: Anish Naik <[email protected]>
1 parent 42a22ca commit ed1ea36

File tree

11 files changed

+198
-32
lines changed

11 files changed

+198
-32
lines changed

chain/standard_cheat_code_contract.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,17 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract,
250250
},
251251
)
252252

253+
// Label: Sets a label for an address.
254+
contract.addMethod(
255+
"label", abi.Arguments{{Type: typeAddress}, {Type: typeString}}, abi.Arguments{},
256+
func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) {
257+
addr := inputs[0].(common.Address)
258+
label := inputs[1].(string)
259+
tracer.chain.Labels[addr] = label
260+
return nil, nil
261+
},
262+
)
263+
253264
// Load: Loads a storage slot value from a given account.
254265
contract.addMethod(
255266
"load", abi.Arguments{{Type: typeAddress}, {Type: typeBytes32}}, abi.Arguments{{Type: typeBytes32}},

chain/test_chain.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ type TestChain struct {
7474
// This is constructed over the kvstore.
7575
db ethdb.Database
7676

77+
// Labels maps an address to its label if one exists. This is useful for execution tracing.
78+
Labels map[common.Address]string
79+
7780
// callTracerRouter forwards tracers.Tracer and TestChainTracer calls to any instances added to it. This
7881
// router is used for non-state changing calls.
7982
callTracerRouter *TestChainTracerRouter
@@ -187,6 +190,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC
187190
db: db,
188191
state: nil,
189192
stateDatabase: stateDatabase,
193+
Labels: make(map[common.Address]string),
190194
transactionTracerRouter: transactionTracerRouter,
191195
callTracerRouter: callTracerRouter,
192196
testChainConfig: testChainConfig,

fuzzing/calls/call_sequence.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ func (cse *CallSequenceElement) String() string {
261261
args, err := method.Inputs.Unpack(cse.Call.Data[4:])
262262
argsText := "<unable to unpack args>"
263263
if err == nil {
264-
argsText, err = valuegeneration.EncodeABIArgumentsToString(method.Inputs, args)
264+
argsText, err = valuegeneration.EncodeABIArgumentsToString(method.Inputs, args, nil)
265265
if err != nil {
266266
argsText = "<unresolved args>"
267267
}
@@ -286,7 +286,7 @@ func (cse *CallSequenceElement) String() string {
286286
cse.Call.GasLimit,
287287
cse.Call.GasPrice.String(),
288288
cse.Call.Value.String(),
289-
utils.TrimLeadingZeroesFromAddress(cse.Call.From.String()),
289+
cse.Call.From.String(),
290290
)
291291
}
292292

fuzzing/calls/call_sequence_execution.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func ExecuteCallSequence(chain *chain.TestChain, callSequence CallSequence) (Cal
172172
// ExecuteCallSequenceWithExecutionTracer attaches an executiontracer.ExecutionTracer to ExecuteCallSequenceIteratively and attaches execution traces to the call sequence elements.
173173
func ExecuteCallSequenceWithExecutionTracer(testChain *chain.TestChain, contractDefinitions contracts.Contracts, callSequence CallSequence, verboseTracing bool) (CallSequence, error) {
174174
// Create a new execution tracer
175-
executionTracer := executiontracer.NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
175+
executionTracer := executiontracer.NewExecutionTracer(contractDefinitions, testChain)
176176
defer executionTracer.Close()
177177

178178
// Execute our sequence with a simple fetch operation provided to obtain each element.

fuzzing/executiontracer/execution_trace.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"encoding/hex"
55
"errors"
66
"fmt"
7+
"github.com/crytic/medusa/utils"
8+
"github.com/ethereum/go-ethereum/common"
79
"regexp"
810
"strings"
911

@@ -13,7 +15,6 @@ import (
1315
"github.com/crytic/medusa/fuzzing/valuegeneration"
1416
"github.com/crytic/medusa/logging"
1517
"github.com/crytic/medusa/logging/colors"
16-
"github.com/crytic/medusa/utils"
1718
"github.com/ethereum/go-ethereum/accounts/abi"
1819
coreTypes "github.com/ethereum/go-ethereum/core/types"
1920
"github.com/ethereum/go-ethereum/core/vm"
@@ -29,13 +30,17 @@ type ExecutionTrace struct {
2930
// contractDefinitions represents the known contract definitions at the time of tracing. This is used to help
3031
// obtain any additional information regarding execution.
3132
contractDefinitions contracts.Contracts
33+
34+
// labels is a mapping that maps an address to its string representation for cleaner execution traces
35+
labels map[common.Address]string
3236
}
3337

3438
// newExecutionTrace creates and returns a new ExecutionTrace, to be used by the ExecutionTracer.
35-
func newExecutionTrace(contracts contracts.Contracts) *ExecutionTrace {
39+
func newExecutionTrace(contracts contracts.Contracts, labels map[common.Address]string) *ExecutionTrace {
3640
return &ExecutionTrace{
3741
TopLevelCallFrame: nil,
3842
contractDefinitions: contracts,
43+
labels: labels,
3944
}
4045
}
4146

@@ -70,10 +75,18 @@ func (t *ExecutionTrace) generateCallFrameEnterElements(callFrame *CallFrame) ([
7075

7176
// Resolve our contract names, as well as our method and its name from the code contract.
7277
if callFrame.ToContractAbi != nil {
78+
// Check to see if there is a label for the proxy address
7379
proxyContractName = callFrame.ToContractName
80+
if label, ok := t.labels[callFrame.ToAddress]; ok {
81+
proxyContractName = label
82+
}
7483
}
7584
if callFrame.CodeContractAbi != nil {
85+
// Check to see if there is a label for the code address
7686
codeContractName = callFrame.CodeContractName
87+
if label, ok := t.labels[callFrame.CodeAddress]; ok {
88+
codeContractName = label
89+
}
7790
if callFrame.IsContractCreation() {
7891
methodName = "constructor"
7992
method = &callFrame.CodeContractAbi.Constructor
@@ -102,8 +115,8 @@ func (t *ExecutionTrace) generateCallFrameEnterElements(callFrame *CallFrame) ([
102115
// Unpack our input values and obtain a string to represent them
103116
inputValues, err := method.Inputs.Unpack(abiDataInputBuffer)
104117
if err == nil {
105-
// Encode the ABI arguments into strings
106-
encodedInputString, err := valuegeneration.EncodeABIArgumentsToString(method.Inputs, inputValues)
118+
// Encode the ABI arguments into strings and provide the label overrides
119+
encodedInputString, err := valuegeneration.EncodeABIArgumentsToString(method.Inputs, inputValues, t.labels)
107120
if err == nil {
108121
inputArgumentsDisplayText = &encodedInputString
109122
}
@@ -137,24 +150,29 @@ func (t *ExecutionTrace) generateCallFrameEnterElements(callFrame *CallFrame) ([
137150
inputArgumentsDisplayText = &temp
138151
}
139152

153+
// Handle all label overrides
154+
toAddress := utils.AttachLabelToAddress(callFrame.ToAddress, t.labels[callFrame.ToAddress])
155+
senderAddress := utils.AttachLabelToAddress(callFrame.SenderAddress, t.labels[callFrame.SenderAddress])
156+
codeAddress := utils.AttachLabelToAddress(callFrame.CodeAddress, t.labels[callFrame.CodeAddress])
157+
140158
// Generate the message we wish to output finally, using all these display string components.
141159
// If we executed code, attach additional context such as the contract name, method, etc.
142160
var callInfo string
143161
if callFrame.IsProxyCall() {
144162
if callFrame.ExecutedCode {
145-
callInfo = fmt.Sprintf("%v -> %v.%v(%v) (addr=%v, code=%v, value=%v, sender=%v)", proxyContractName, codeContractName, methodName, *inputArgumentsDisplayText, utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), utils.TrimLeadingZeroesFromAddress(callFrame.CodeAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
163+
callInfo = fmt.Sprintf("%v -> %v.%v(%v) (addr=%v, code=%v, value=%v, sender=%v)", proxyContractName, codeContractName, methodName, *inputArgumentsDisplayText, toAddress, codeAddress, callFrame.CallValue, senderAddress)
146164
} else {
147-
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
165+
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", toAddress, callFrame.CallValue, senderAddress)
148166
}
149167
} else {
150168
if callFrame.ExecutedCode {
151169
if callFrame.ToAddress == chain.ConsoleLogContractAddress {
152170
callInfo = fmt.Sprintf("%v.%v(%v)", codeContractName, methodName, *inputArgumentsDisplayText)
153171
} else {
154-
callInfo = fmt.Sprintf("%v.%v(%v) (addr=%v, value=%v, sender=%v)", codeContractName, methodName, *inputArgumentsDisplayText, utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
172+
callInfo = fmt.Sprintf("%v.%v(%v) (addr=%v, value=%v, sender=%v)", codeContractName, methodName, *inputArgumentsDisplayText, toAddress, callFrame.CallValue, senderAddress)
155173
}
156174
} else {
157-
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", utils.TrimLeadingZeroesFromAddress(callFrame.ToAddress.String()), callFrame.CallValue, utils.TrimLeadingZeroesFromAddress(callFrame.SenderAddress.String()))
175+
callInfo = fmt.Sprintf("(addr=%v, value=%v, sender=%v)", toAddress, callFrame.CallValue, senderAddress)
158176
}
159177
}
160178

@@ -189,7 +207,7 @@ func (t *ExecutionTrace) generateCallFrameExitElements(callFrame *CallFrame) []a
189207
if callFrame.ReturnError == nil {
190208
outputValues, err := method.Outputs.Unpack(callFrame.ReturnData)
191209
if err == nil {
192-
encodedOutputString, err := valuegeneration.EncodeABIArgumentsToString(method.Outputs, outputValues)
210+
encodedOutputString, err := valuegeneration.EncodeABIArgumentsToString(method.Outputs, outputValues, t.labels)
193211
if err == nil {
194212
outputArgumentsDisplayText = &encodedOutputString
195213
}
@@ -232,7 +250,7 @@ func (t *ExecutionTrace) generateCallFrameExitElements(callFrame *CallFrame) []a
232250
// Try to unpack a custom Solidity error from the return values.
233251
matchedCustomError, unpackedCustomErrorArgs := abiutils.GetSolidityCustomRevertError(callFrame.CodeContractAbi, callFrame.ReturnError, callFrame.ReturnData)
234252
if matchedCustomError != nil {
235-
customErrorArgsDisplayText, err := valuegeneration.EncodeABIArgumentsToString(matchedCustomError.Inputs, unpackedCustomErrorArgs)
253+
customErrorArgsDisplayText, err := valuegeneration.EncodeABIArgumentsToString(matchedCustomError.Inputs, unpackedCustomErrorArgs, t.labels)
236254
if err == nil {
237255
elements = append(elements, colors.RedBold, fmt.Sprintf("[revert (error: %v(%v))]", matchedCustomError.Name, customErrorArgsDisplayText), colors.Reset, "\n")
238256
return elements
@@ -276,7 +294,7 @@ func (t *ExecutionTrace) generateEventEmittedElements(callFrame *CallFrame, even
276294
// If we resolved an event definition and unpacked data.
277295
if event != nil {
278296
// Format the values as a comma-separated string
279-
encodedEventValuesString, err := valuegeneration.EncodeABIArgumentsToString(event.Inputs, eventInputValues)
297+
encodedEventValuesString, err := valuegeneration.EncodeABIArgumentsToString(event.Inputs, eventInputValues, t.labels)
280298
if err == nil {
281299
// Format our event display text finally, with the event name.
282300
temp := fmt.Sprintf("%v(%v)", event.Name, encodedEventValuesString)

fuzzing/executiontracer/execution_tracer.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
// Returns the ExecutionTrace for the call or an error if one occurs.
2323
func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state *state.StateDB) (*core.ExecutionResult, *ExecutionTrace, error) {
2424
// Create an execution tracer
25-
executionTracer := NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
25+
executionTracer := NewExecutionTracer(contractDefinitions, testChain)
2626
defer executionTracer.Close()
2727

2828
// Call the contract on our chain with the provided state.
@@ -48,6 +48,11 @@ type ExecutionTracer struct {
4848
// trace represents the current execution trace captured by this tracer.
4949
trace *ExecutionTrace
5050

51+
// testChain represents the underlying chain that the execution tracer runs on
52+
testChain *chain.TestChain
53+
54+
// traceMap describes a mapping that allows someone to retrieve the execution trace for a common transaction
55+
// hash.
5156
traceMap map[common.Hash]*ExecutionTrace
5257

5358
// currentCallFrame references the current call frame being traced.
@@ -56,23 +61,21 @@ type ExecutionTracer struct {
5661
// contractDefinitions represents the contract definitions to match for execution traces.
5762
contractDefinitions contracts.Contracts
5863

59-
// cheatCodeContracts represents the cheat code contract definitions to match for execution traces.
60-
cheatCodeContracts map[common.Address]*chain.CheatCodeContract
61-
6264
// onNextCaptureState refers to methods which should be executed the next time OnOpcode executes.
6365
// OnOpcode is called prior to execution of an instruction. This allows actions to be performed
6466
// after some state is captured, on the next state capture (e.g. detecting a log instruction, but
6567
// using this structure to execute code later once the log is committed).
6668
onNextCaptureState []func()
6769

70+
// nativeTracer is the underlying tracer interface that the execution tracer follows
6871
nativeTracer *chain.TestChainTracer
6972
}
7073

7174
// NewExecutionTracer creates a ExecutionTracer and returns it.
72-
func NewExecutionTracer(contractDefinitions contracts.Contracts, cheatCodeContracts map[common.Address]*chain.CheatCodeContract) *ExecutionTracer {
75+
func NewExecutionTracer(contractDefinitions contracts.Contracts, testChain *chain.TestChain) *ExecutionTracer {
7376
tracer := &ExecutionTracer{
7477
contractDefinitions: contractDefinitions,
75-
cheatCodeContracts: cheatCodeContracts,
78+
testChain: testChain,
7679
traceMap: make(map[common.Hash]*ExecutionTrace),
7780
}
7881
innerTracer := &tracers.Tracer{
@@ -122,7 +125,7 @@ func (t *ExecutionTracer) OnTxEnd(receipt *coretypes.Receipt, err error) {
122125
// OnTxStart is called upon the start of transaction execution, as defined by tracers.Tracer.
123126
func (t *ExecutionTracer) OnTxStart(vm *tracing.VMContext, tx *coretypes.Transaction, from common.Address) {
124127
// Reset our capture state
125-
t.trace = newExecutionTrace(t.contractDefinitions)
128+
t.trace = newExecutionTrace(t.contractDefinitions, t.testChain.Labels)
126129
t.currentCallFrame = nil
127130
t.onNextCaptureState = nil
128131
t.traceMap = make(map[common.Hash]*ExecutionTrace)
@@ -151,7 +154,7 @@ func (t *ExecutionTracer) resolveCallFrameContractDefinitions(callFrame *CallFra
151154
// Try to resolve contract definitions for "to" address
152155
if callFrame.ToContractAbi == nil {
153156
// Try to resolve definitions from cheat code contracts
154-
if cheatCodeContract, ok := t.cheatCodeContracts[callFrame.ToAddress]; ok {
157+
if cheatCodeContract, ok := t.testChain.CheatCodeContracts()[callFrame.ToAddress]; ok {
155158
callFrame.ToContractName = cheatCodeContract.Name()
156159
callFrame.ToContractAbi = cheatCodeContract.Abi()
157160
callFrame.ExecutedCode = true
@@ -175,7 +178,7 @@ func (t *ExecutionTracer) resolveCallFrameContractDefinitions(callFrame *CallFra
175178
// Try to resolve contract definitions for "code" address
176179
if callFrame.CodeContractAbi == nil {
177180
// Try to resolve definitions from cheat code contracts
178-
if cheatCodeContract, ok := t.cheatCodeContracts[callFrame.CodeAddress]; ok {
181+
if cheatCodeContract, ok := t.testChain.CheatCodeContracts()[callFrame.CodeAddress]; ok {
179182
callFrame.CodeContractName = cheatCodeContract.Name()
180183
callFrame.CodeContractAbi = cheatCodeContract.Abi()
181184
callFrame.ExecutedCode = true

fuzzing/fuzzer_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,56 @@ func TestExecutionTraces(t *testing.T) {
602602
}
603603
}
604604

605+
// TestLabelCheatCode tests the vm.label cheatcode.
606+
func TestLabelCheatCode(t *testing.T) {
607+
// These are the expected messages in the execution trace
608+
expectedTraceMessages := []string{
609+
"ProxyContract.testVMLabel()()",
610+
"addr=ProxyContract [0xA647ff3c36cFab592509E13860ab8c4F28781a66]",
611+
"sender=MySender [0x10000]",
612+
"ProxyContract -> ImplementationContract.emitEvent(address)(ProxyContract [0xA647ff3c36cFab592509E13860ab8c4F28781a66])",
613+
"code=ImplementationContract [0x54919A19522Ce7c842E25735a9cFEcef1c0a06dA]",
614+
"[event] TestEvent(RandomAddress [0x20000])",
615+
"[return (ProxyContract [0xA647ff3c36cFab592509E13860ab8c4F28781a66])]",
616+
}
617+
runFuzzerTest(t, &fuzzerSolcFileTest{
618+
filePath: "testdata/contracts/cheat_codes/utils/label.sol",
619+
configUpdates: func(config *config.ProjectConfig) {
620+
config.Fuzzing.TargetContracts = []string{"TestContract"}
621+
// Only allow for one sender for proper testing of this unit test
622+
config.Fuzzing.SenderAddresses = []string{"0x10000"}
623+
config.Fuzzing.Testing.PropertyTesting.Enabled = false
624+
config.Fuzzing.Testing.OptimizationTesting.Enabled = false
625+
config.Slither.UseSlither = false
626+
},
627+
method: func(f *fuzzerTestContext) {
628+
// Start the fuzzer
629+
err := f.fuzzer.Start()
630+
assert.NoError(t, err)
631+
632+
// Check for failed assertion tests.
633+
failedTestCase := f.fuzzer.TestCasesWithStatus(TestCaseStatusFailed)
634+
assert.NotEmpty(t, failedTestCase, "expected to have failed test cases")
635+
636+
// Obtain our first failed test case, get the message, and verify it contains our assertion failed.
637+
failingSequence := *failedTestCase[0].CallSequence()
638+
assert.NotEmpty(t, failingSequence, "expected to have calls in the call sequence failing an assertion test")
639+
640+
// Obtain the last call
641+
lastCall := failingSequence[len(failingSequence)-1]
642+
assert.NotNilf(t, lastCall.ExecutionTrace, "expected to have an execution trace attached to call sequence for this test")
643+
644+
// Get the execution trace message
645+
executionTraceMsg := lastCall.ExecutionTrace.Log().String()
646+
647+
// Verify it contains all expected strings
648+
for _, expectedTraceMessage := range expectedTraceMessages {
649+
assert.Contains(t, executionTraceMsg, expectedTraceMessage)
650+
}
651+
},
652+
})
653+
}
654+
605655
// TestTestingScope runs tests to ensure dynamically deployed contracts are tested when the "test all contracts"
606656
// config option is specified. It also runs the fuzzer without the option enabled to ensure they are not tested.
607657
func TestTestingScope(t *testing.T) {

fuzzing/fuzzer_worker.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,6 @@ func (fw *FuzzerWorker) run(baseTestChain *chain.TestChain) (bool, error) {
570570
initializedChain.Events.ContractDeploymentAddedEventEmitter.Subscribe(fw.onChainContractDeploymentAddedEvent)
571571
initializedChain.Events.ContractDeploymentRemovedEventEmitter.Subscribe(fw.onChainContractDeploymentRemovedEvent)
572572

573-
// Emit an event indicating the worker has created its chain.
574-
err = fw.Events.FuzzerWorkerChainCreated.Publish(FuzzerWorkerChainCreatedEvent{
575-
Worker: fw,
576-
Chain: initializedChain,
577-
})
578573
if err != nil {
579574
return fmt.Errorf("error returned by an event handler when emitting a worker chain created event: %v", err)
580575
}
@@ -584,6 +579,15 @@ func (fw *FuzzerWorker) run(baseTestChain *chain.TestChain) (bool, error) {
584579
fw.coverageTracer = coverage.NewCoverageTracer()
585580
initializedChain.AddTracer(fw.coverageTracer.NativeTracer(), true, false)
586581
}
582+
583+
// Copy the labels from the base chain to the worker's chain
584+
initializedChain.Labels = maps.Clone(baseTestChain.Labels)
585+
586+
// Emit an event indicating the worker has created its chain.
587+
err = fw.Events.FuzzerWorkerChainCreated.Publish(FuzzerWorkerChainCreatedEvent{
588+
Worker: fw,
589+
Chain: initializedChain,
590+
})
587591
return nil
588592
})
589593

0 commit comments

Comments
 (0)