Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

startPrank and stopPrank cheatcode #452

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions chain/cheat_code_contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chain
import (
"encoding/binary"
"fmt"

"github.com/crytic/medusa/logging"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
Expand All @@ -27,6 +28,9 @@ type CheatCodeContract struct {

// abi refers to the cheat code contract's ABI definition.
abi abi.ABI

// storage holds values stored by cheatcodes
storage map[string]map[string]any
}

// cheatCodeMethod defines the method information for a given precompiledContract.
Expand Down Expand Up @@ -94,6 +98,7 @@ func newCheatCodeContract(tracer *cheatCodeTracer, address common.Address, name
Fallback: abi.Method{},
Receive: abi.Method{},
},
storage: make(map[string]map[string]any),
}
}

Expand Down
58 changes: 56 additions & 2 deletions chain/standard_cheat_code_contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strconv"
"strings"

"github.com/crytic/medusa/chain/types"
"github.com/crytic/medusa/utils"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -238,7 +239,7 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract,
},
)

// Prank: Sets the msg.sender within the next EVM call scope created by the caller.
// prank: Sets the msg.sender within the next EVM call scope created by the caller.
contract.addMethod(
"prank", abi.Arguments{{Type: typeAddress}}, abi.Arguments{},
func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) {
Expand All @@ -261,7 +262,60 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract,
},
)

// PrankHere: Sets the msg.sender within caller EVM scope until it is exited.
// startPrank: Sets the msg.sender within external calls until stopPrank is called
contract.addMethod(
"startPrank", abi.Arguments{{Type: typeAddress}}, abi.Arguments{},
func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) {
// Obtain the caller frame.
cheatCodeCallerFrame := tracer.PreviousCallFrame()

// Initialize storage for startPrank
_, ok := contract.storage["startPrank"]
if !ok {
contract.storage["startPrank"] = map[string]any{}
}

// Store new caller address
newCallerAddress := inputs[0].(common.Address)
contract.storage["startPrank"]["newCallerAddress"] = newCallerAddress

var nextFrameEnterHook types.GenericHookFunc
nextFrameEnterHook = func() {
prankCallFrame := tracer.CurrentCallFrame()
scopeContext := prankCallFrame.vmScope.(*vm.ScopeContext)

// If we don't have a caller address, stopPrank was called and we can stop propagating the hook
// If the caller address has changed, there has been a new call to startPrank and we can stop propagating the hook
if callerAddress, ok := contract.storage["startPrank"]["newCallerAddress"]; ok && callerAddress == newCallerAddress {
// Override the caller address for the current scope
scopeContext.Contract.CallerAddress = newCallerAddress

// Re-attach hook to override caller address for subsequent external calls
prankCallFrame.onFrameExitRestoreHooks.Push(func() {
currentCallFrame := tracer.PreviousCallFrame()
currentCallFrame.onNextFrameEnterHooks.Push(nextFrameEnterHook)
})
}
}

cheatCodeCallerFrame.onNextFrameEnterHooks.Push(nextFrameEnterHook)

return nil, nil
},
)

// stopPrank: Resets msg.sender altered by startPrank
contract.addMethod(
"stopPrank", abi.Arguments{}, abi.Arguments{},
func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) {
// Delete new caller address
delete(contract.storage["startPrank"], "newCallerAddress")

return nil, nil
},
)

// prankHere: Sets the msg.sender within caller EVM scope until it is exited.
contract.addMethod(
"prankHere", abi.Arguments{{Type: typeAddress}}, abi.Arguments{},
func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) {
Expand Down
1 change: 1 addition & 0 deletions fuzzing/fuzzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func TestCheatCodes(t *testing.T) {
"testdata/contracts/cheat_codes/vm/etch.sol",
"testdata/contracts/cheat_codes/vm/fee.sol",
"testdata/contracts/cheat_codes/vm/prank.sol",
"testdata/contracts/cheat_codes/vm/start_prank.sol",
"testdata/contracts/cheat_codes/vm/roll.sol",
"testdata/contracts/cheat_codes/vm/store_load.sol",
"testdata/contracts/cheat_codes/vm/warp.sol",
Expand Down
116 changes: 116 additions & 0 deletions fuzzing/testdata/contracts/cheat_codes/vm/start_prank.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

// This test startPrank (spoof msg.sender on subsequent calls in the current scope) and stopPrank (restore original msg.sender).
interface CheatCodes {
function addr(uint256 privateKey) external returns (address);
function startPrank(address) external;
function stopPrank() external;
}

contract Ownable {
address public owner;

constructor() {
owner = msg.sender;
}
}

contract TestContract {
TestContract thisExternal = TestContract(address(this));
int currentDepth;
bool calledThroughTestFunction;

Ownable public one;
Ownable public two;
Ownable public three;
Ownable public four;
Ownable public five;
Ownable public six;
Ownable public seven;
address prankAddr = address(7);

// Obtain our cheat code contract reference.
CheatCodes cheats = CheatCodes(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D);

function prankAndGetSenderAtDepth(
int prankDepth,
address prankAddress,
int senderDepth
) public returns (address) {
// This can't be called directly
require(calledThroughTestFunction);

// This test should be written so prank depth is never less than the sender fetching depth, and the sender
// depth must be a positive number. Prank depth can be negative, to not prank at all.
require(senderDepth >= prankDepth);
require(senderDepth >= 0);
require(senderDepth < 10); // disallow printing at depth > 10, in case fuzzer hits function this directly

// If we are at the depth we wanted to prank at, call prank.
if (currentDepth == prankDepth) {
// Change value and verify.
cheats.startPrank(prankAddress);
}

// If we haven't reached our depth limit, return our result from a further external call.
if (currentDepth < senderDepth) {
currentDepth++;
address sender = thisExternal.prankAndGetSenderAtDepth(
prankDepth,
prankAddress,
senderDepth
);
currentDepth--;
return sender;
}

// We have reached our depth limit, return the sender.
return msg.sender;
}

function test() public {
address owner1 = cheats.addr(0x123456);
address owner2 = cheats.addr(0x234567);
address originalMsgSender = msg.sender;
address thisExternalAddr = address(this);

cheats.startPrank(owner1);
one = new Ownable();
two = new Ownable();
three = new Ownable();

// Ensure that the msg.sender for this scope has not changed
assert(msg.sender == originalMsgSender);

// Check that the prank changed the msg.sender for subsequent external calls
assert(one.owner() == owner1);
assert(two.owner() == owner1);
assert(three.owner() == owner1);

cheats.startPrank(owner2);
four = new Ownable();
five = new Ownable();

// Ensure that the pranked address has changed
assert(four.owner() == owner2);
assert(five.owner() == owner2);

cheats.stopPrank();
six = new Ownable();
seven = new Ownable();

// Check that the msg.sender of external calls is reset to the original
assert(six.owner() == thisExternalAddr);
assert(seven.owner() == thisExternalAddr);

calledThroughTestFunction = true;

// Check that the msg.sender for nested calls is not spoofed by startPrank
assert(prankAndGetSenderAtDepth(1, prankAddr, 1) == thisExternalAddr);
assert(prankAndGetSenderAtDepth(1, prankAddr, 2) == prankAddr);
assert(prankAndGetSenderAtDepth(1, prankAddr, 3) == thisExternalAddr);

calledThroughTestFunction = false;
}
}