diff --git a/chain/cheat_code_contract.go b/chain/cheat_code_contract.go index 08ceb6ea..94f130f5 100644 --- a/chain/cheat_code_contract.go +++ b/chain/cheat_code_contract.go @@ -3,6 +3,7 @@ package chain import ( "encoding/binary" "fmt" + "github.com/crytic/medusa/logging" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" @@ -27,6 +28,9 @@ type CheatCodeContract struct { // abi refers to the cheat code contract's ABI definition. abi abi.ABI + + // storage holds values stored by cheatcodes + storage map[string]map[string]any } // cheatCodeMethod defines the method information for a given precompiledContract. @@ -94,6 +98,7 @@ func newCheatCodeContract(tracer *cheatCodeTracer, address common.Address, name Fallback: abi.Method{}, Receive: abi.Method{}, }, + storage: make(map[string]map[string]any), } } diff --git a/chain/standard_cheat_code_contract.go b/chain/standard_cheat_code_contract.go index 97eb3288..86a359b2 100644 --- a/chain/standard_cheat_code_contract.go +++ b/chain/standard_cheat_code_contract.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/crytic/medusa/chain/types" "github.com/crytic/medusa/utils" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" @@ -238,7 +239,7 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, }, ) - // Prank: Sets the msg.sender within the next EVM call scope created by the caller. + // prank: Sets the msg.sender within the next EVM call scope created by the caller. contract.addMethod( "prank", abi.Arguments{{Type: typeAddress}}, abi.Arguments{}, func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { @@ -261,7 +262,60 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, }, ) - // PrankHere: Sets the msg.sender within caller EVM scope until it is exited. + // startPrank: Sets the msg.sender within external calls until stopPrank is called + contract.addMethod( + "startPrank", abi.Arguments{{Type: typeAddress}}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + // Obtain the caller frame. + cheatCodeCallerFrame := tracer.PreviousCallFrame() + + // Initialize storage for startPrank + _, ok := contract.storage["startPrank"] + if !ok { + contract.storage["startPrank"] = map[string]any{} + } + + // Store new caller address + newCallerAddress := inputs[0].(common.Address) + contract.storage["startPrank"]["newCallerAddress"] = newCallerAddress + + var nextFrameEnterHook types.GenericHookFunc + nextFrameEnterHook = func() { + prankCallFrame := tracer.CurrentCallFrame() + scopeContext := prankCallFrame.vmScope.(*vm.ScopeContext) + + // If we don't have a caller address, stopPrank was called and we can stop propagating the hook + // If the caller address has changed, there has been a new call to startPrank and we can stop propagating the hook + if callerAddress, ok := contract.storage["startPrank"]["newCallerAddress"]; ok && callerAddress == newCallerAddress { + // Override the caller address for the current scope + scopeContext.Contract.CallerAddress = newCallerAddress + + // Re-attach hook to override caller address for subsequent external calls + prankCallFrame.onFrameExitRestoreHooks.Push(func() { + currentCallFrame := tracer.PreviousCallFrame() + currentCallFrame.onNextFrameEnterHooks.Push(nextFrameEnterHook) + }) + } + } + + cheatCodeCallerFrame.onNextFrameEnterHooks.Push(nextFrameEnterHook) + + return nil, nil + }, + ) + + // stopPrank: Resets msg.sender altered by startPrank + contract.addMethod( + "stopPrank", abi.Arguments{}, abi.Arguments{}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + // Delete new caller address + delete(contract.storage["startPrank"], "newCallerAddress") + + return nil, nil + }, + ) + + // prankHere: Sets the msg.sender within caller EVM scope until it is exited. contract.addMethod( "prankHere", abi.Arguments{{Type: typeAddress}}, abi.Arguments{}, func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { diff --git a/fuzzing/fuzzer_test.go b/fuzzing/fuzzer_test.go index 0af56dd9..44bbec2d 100644 --- a/fuzzing/fuzzer_test.go +++ b/fuzzing/fuzzer_test.go @@ -222,6 +222,7 @@ func TestCheatCodes(t *testing.T) { "testdata/contracts/cheat_codes/vm/etch.sol", "testdata/contracts/cheat_codes/vm/fee.sol", "testdata/contracts/cheat_codes/vm/prank.sol", + "testdata/contracts/cheat_codes/vm/start_prank.sol", "testdata/contracts/cheat_codes/vm/roll.sol", "testdata/contracts/cheat_codes/vm/store_load.sol", "testdata/contracts/cheat_codes/vm/warp.sol", diff --git a/fuzzing/testdata/contracts/cheat_codes/vm/start_prank.sol b/fuzzing/testdata/contracts/cheat_codes/vm/start_prank.sol new file mode 100644 index 00000000..5c4a3c1c --- /dev/null +++ b/fuzzing/testdata/contracts/cheat_codes/vm/start_prank.sol @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.13; + +// This test startPrank (spoof msg.sender on subsequent calls in the current scope) and stopPrank (restore original msg.sender). +interface CheatCodes { + function addr(uint256 privateKey) external returns (address); + function startPrank(address) external; + function stopPrank() external; +} + +contract Ownable { + address public owner; + + constructor() { + owner = msg.sender; + } +} + +contract TestContract { + TestContract thisExternal = TestContract(address(this)); + int currentDepth; + bool calledThroughTestFunction; + + Ownable public one; + Ownable public two; + Ownable public three; + Ownable public four; + Ownable public five; + Ownable public six; + Ownable public seven; + address prankAddr = address(7); + + // Obtain our cheat code contract reference. + CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + + function prankAndGetSenderAtDepth( + int prankDepth, + address prankAddress, + int senderDepth + ) public returns (address) { + // This can't be called directly + require(calledThroughTestFunction); + + // This test should be written so prank depth is never less than the sender fetching depth, and the sender + // depth must be a positive number. Prank depth can be negative, to not prank at all. + require(senderDepth >= prankDepth); + require(senderDepth >= 0); + require(senderDepth < 10); // disallow printing at depth > 10, in case fuzzer hits function this directly + + // If we are at the depth we wanted to prank at, call prank. + if (currentDepth == prankDepth) { + // Change value and verify. + cheats.startPrank(prankAddress); + } + + // If we haven't reached our depth limit, return our result from a further external call. + if (currentDepth < senderDepth) { + currentDepth++; + address sender = thisExternal.prankAndGetSenderAtDepth( + prankDepth, + prankAddress, + senderDepth + ); + currentDepth--; + return sender; + } + + // We have reached our depth limit, return the sender. + return msg.sender; + } + + function test() public { + address owner1 = cheats.addr(0x123456); + address owner2 = cheats.addr(0x234567); + address originalMsgSender = msg.sender; + address thisExternalAddr = address(this); + + cheats.startPrank(owner1); + one = new Ownable(); + two = new Ownable(); + three = new Ownable(); + + // Ensure that the msg.sender for this scope has not changed + assert(msg.sender == originalMsgSender); + + // Check that the prank changed the msg.sender for subsequent external calls + assert(one.owner() == owner1); + assert(two.owner() == owner1); + assert(three.owner() == owner1); + + cheats.startPrank(owner2); + four = new Ownable(); + five = new Ownable(); + + // Ensure that the pranked address has changed + assert(four.owner() == owner2); + assert(five.owner() == owner2); + + cheats.stopPrank(); + six = new Ownable(); + seven = new Ownable(); + + // Check that the msg.sender of external calls is reset to the original + assert(six.owner() == thisExternalAddr); + assert(seven.owner() == thisExternalAddr); + + calledThroughTestFunction = true; + + // Check that the msg.sender for nested calls is not spoofed by startPrank + assert(prankAndGetSenderAtDepth(1, prankAddr, 1) == thisExternalAddr); + assert(prankAndGetSenderAtDepth(1, prankAddr, 2) == prankAddr); + assert(prankAndGetSenderAtDepth(1, prankAddr, 3) == thisExternalAddr); + + calledThroughTestFunction = false; + } +}