Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add vm.startPrank and vm.stopPrank cheatcodes #22

Merged
merged 6 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 95 additions & 7 deletions e2e-tests/contracts/TestCheatcodes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ contract TestCheatcodes {
require(success, "setGreeting failed");
}

function testRoll(uint256 blockNumber) external {
uint256 initialBlockNumber = block.number;
require(blockNumber != initialBlockNumber, "block number must be different than current block number");

(bool success, ) = CHEATCODE_ADDRESS.call(abi.encodeWithSignature("roll(uint256)", blockNumber));
require(success, "roll failed");

uint256 finalBlockNumber = block.number;
require(finalBlockNumber == blockNumber, "block number was not changed");
}

function testSetNonce(address account, uint64 nonce) external {
(bool success, bytes memory data) = CHEATCODE_ADDRESS.call(
abi.encodeWithSignature("setNonce(address,uint64)", account, nonce)
Expand All @@ -32,15 +43,80 @@ contract TestCheatcodes {
require(finalNonce == nonce, "nonce mismatch");
}

function testRoll(uint256 blockNumber) external {
uint256 initialBlockNumber = block.number;
require(blockNumber != initialBlockNumber, "block number must be different than current block number");
function testStartPrank(address account) external {
address original_msg_sender = msg.sender;
address original_tx_origin = tx.origin;

(bool success, ) = CHEATCODE_ADDRESS.call(abi.encodeWithSignature("roll(uint256)", blockNumber));
require(success, "roll failed");
PrankVictim victim = new PrankVictim();

uint256 finalBlockNumber = block.number;
require(finalBlockNumber == blockNumber, "block number was not changed");
victim.assertCallerAndOrigin(
address(this),
"startPrank failed: victim.assertCallerAndOrigin failed",
original_tx_origin,
"startPrank failed: victim.assertCallerAndOrigin failed"
);

(bool success1, ) = CHEATCODE_ADDRESS.call(abi.encodeWithSignature("startPrank(address)", account));
require(success1, "startPrank failed");

require(msg.sender == account, "startPrank failed: msg.sender unchanged");
require(tx.origin == original_tx_origin, "startPrank failed tx.origin changed");
victim.assertCallerAndOrigin(
account,
"startPrank failed: victim.assertCallerAndOrigin failed",
original_tx_origin,
"startPrank failed: victim.assertCallerAndOrigin failed"
);

(bool success2, ) = CHEATCODE_ADDRESS.call(abi.encodeWithSignature("stopPrank()"));
require(success2, "stopPrank failed");

require(msg.sender == original_msg_sender, "stopPrank failed: msg.sender didn't return to original");
require(tx.origin == original_tx_origin, "stopPrank failed tx.origin changed");
victim.assertCallerAndOrigin(
address(this),
"startPrank failed: victim.assertCallerAndOrigin failed",
original_tx_origin,
"startPrank failed: victim.assertCallerAndOrigin failed"
);
}

function testStartPrankWithOrigin(address account, address origin) external {
address original_msg_sender = msg.sender;
address original_tx_origin = tx.origin;

PrankVictim victim = new PrankVictim();

victim.assertCallerAndOrigin(
address(this),
"startPrank failed: victim.assertCallerAndOrigin failed",
original_tx_origin,
"startPrank failed: victim.assertCallerAndOrigin failed"
);

(bool success1, ) = CHEATCODE_ADDRESS.call(abi.encodeWithSignature("startPrank(address,address)", account, origin));
require(success1, "startPrank failed");

require(msg.sender == account, "startPrank failed: msg.sender unchanged");
require(tx.origin == origin, "startPrank failed: tx.origin unchanged");
victim.assertCallerAndOrigin(
account,
"startPrank failed: victim.assertCallerAndOrigin failed",
origin,
"startPrank failed: victim.assertCallerAndOrigin failed"
);

(bool success2, ) = CHEATCODE_ADDRESS.call(abi.encodeWithSignature("stopPrank()"));
require(success2, "stopPrank failed");

require(msg.sender == original_msg_sender, "stopPrank failed: msg.sender didn't return to original");
require(tx.origin == original_tx_origin, "stopPrank failed: tx.origin didn't return to original");
victim.assertCallerAndOrigin(
address(this),
"startPrank failed: victim.assertCallerAndOrigin failed",
original_tx_origin,
"startPrank failed: victim.assertCallerAndOrigin failed"
);
}

function testWarp(uint256 timestamp) external {
Expand All @@ -54,3 +130,15 @@ contract TestCheatcodes {
require(finalTimestamp == timestamp, "timestamp was not changed");
}
}

contract PrankVictim {
function assertCallerAndOrigin(
address expectedSender,
string memory senderMessage,
address expectedOrigin,
string memory originMessage
) public view {
require(msg.sender == expectedSender, senderMessage);
require(tx.origin == expectedOrigin, originMessage);
}
}
52 changes: 46 additions & 6 deletions e2e-tests/test/cheatcodes.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,23 @@ describe("Cheatcodes", function () {
expect(finalRandomWalletCode).to.not.eq(initialRandomWalletCode);
});

it("Should test vm.roll", async function () {
// Arrange
const wallet = new Wallet(RichAccounts[0].PrivateKey);
const deployer = new Deployer(hre, wallet);
const contract = await deployContract(deployer, "TestCheatcodes", []);

const blockNumber = await provider.getBlockNumber();
const newBlockNumber = blockNumber + 345;

// Act
const tx = await contract.testRoll(newBlockNumber, { gasLimit: 1000000 });
const receipt = await tx.wait();

// Assert
expect(receipt.status).to.eq(1);
});

it("Should test vm.setNonce and vm.getNonce", async function () {
// Arrange
const wallet = new Wallet(RichAccounts[0].PrivateKey);
Expand All @@ -66,23 +83,46 @@ describe("Cheatcodes", function () {
expect(finalNonce).to.eq(1234);
});

it("Should test vm.roll", async function () {
it("Should test vm.startPrank", async function () {
// Arrange
const wallet = new Wallet(RichAccounts[0].PrivateKey);
const deployer = new Deployer(hre, wallet);
const contract = await deployContract(deployer, "TestCheatcodes", []);

const blockNumber = await provider.getBlockNumber();
const newBlockNumber = blockNumber + 345;
const randomWallet = Wallet.createRandom().connect(provider);

// Act
const tx = await contract.testRoll(newBlockNumber, { gasLimit: 1000000 });
const cheatcodes = await deployContract(deployer, "TestCheatcodes", []);
const tx = await cheatcodes.testStartPrank(randomWallet.address, {
gasLimit: 10000000,
});
const receipt = await tx.wait();

// Assert
expect(receipt.status).to.eq(1);
});

it("Should test vm.startPrank with tx.origin", async function () {
// Arrange
const wallet = new Wallet(RichAccounts[0].PrivateKey);
const deployer = new Deployer(hre, wallet);
const randomMsgSender = Wallet.createRandom().connect(provider);
const randomTxOrigin = Wallet.createRandom().connect(provider);

// Act
const cheatcodes = await deployContract(deployer, "TestCheatcodes", []);
const tx1 = await cheatcodes.testStartPrank(randomMsgSender.address, {
gasLimit: 10000000,
});
const receipt1 = await tx1.wait();
const tx2 = await cheatcodes.testStartPrankWithOrigin(randomMsgSender.address, randomTxOrigin.address, {
gasLimit: 10000000,
});
const receipt2 = await tx2.wait();

// Assert
expect(receipt1.status).to.eq(1);
expect(receipt2.status).to.eq(1);
});

it("Should test vm.warp", async function () {
// Arrange
const wallet = new Wallet(RichAccounts[0].PrivateKey);
Expand Down
103 changes: 91 additions & 12 deletions src/cheatcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
fmt::Debug,
sync::{Arc, Mutex, RwLock},
};
use zksync_basic_types::{AccountTreeId, H160, H256, U256};
use zksync_basic_types::{AccountTreeId, Address, H160, H256, U256};
use zksync_state::{StoragePtr, WriteStorage};
use zksync_types::{
block::{pack_block_info, unpack_block_info},
Expand All @@ -36,12 +36,42 @@
113, 9, 112, 158, 207, 169, 26, 128, 98, 111, 243, 152, 157, 104, 246, 127, 91, 29, 209, 45,
]);

const INTERNAL_CONTRACT_ADDRESSES: [H160; 20] = [
zksync_types::BOOTLOADER_ADDRESS,
zksync_types::ACCOUNT_CODE_STORAGE_ADDRESS,
zksync_types::NONCE_HOLDER_ADDRESS,
zksync_types::KNOWN_CODES_STORAGE_ADDRESS,
zksync_types::IMMUTABLE_SIMULATOR_STORAGE_ADDRESS,
zksync_types::CONTRACT_DEPLOYER_ADDRESS,
zksync_types::CONTRACT_FORCE_DEPLOYER_ADDRESS,
zksync_types::L1_MESSENGER_ADDRESS,
zksync_types::MSG_VALUE_SIMULATOR_ADDRESS,
zksync_types::KECCAK256_PRECOMPILE_ADDRESS,
zksync_types::L2_ETH_TOKEN_ADDRESS,
zksync_types::SYSTEM_CONTEXT_ADDRESS,
zksync_types::BOOTLOADER_UTILITIES_ADDRESS,
zksync_types::EVENT_WRITER_ADDRESS,
zksync_types::COMPRESSOR_ADDRESS,
zksync_types::COMPLEX_UPGRADER_ADDRESS,
zksync_types::ECRECOVER_PRECOMPILE_ADDRESS,
zksync_types::SHA256_PRECOMPILE_ADDRESS,
zksync_types::MINT_AND_BURN_ADDRESS,
H160::zero(),
];

#[derive(Clone, Debug, Default)]
pub struct CheatcodeTracer<F> {
node_ctx: F,
returndata: Option<Vec<U256>>,
return_ptr: Option<FatPointer>,
near_calls: usize,
start_prank_opts: Option<StartPrankOpts>,
}

#[derive(Clone, Debug)]
pub struct StartPrankOpts {
sender: Address,
origin: Option<Address>,
}

pub trait NodeCtx {
Expand All @@ -54,9 +84,12 @@
r#"[
function deal(address who, uint256 newBalance)
function etch(address who, bytes calldata code)
function setNonce(address account, uint64 nonce)
function getNonce(address account)
function roll(uint256 blockNumber)
function setNonce(address account, uint64 nonce)
function startPrank(address sender)
function startPrank(address sender, address origin)
function stopPrank()
function warp(uint256 timestamp)
]"#
);
Expand Down Expand Up @@ -146,6 +179,14 @@
timestamp,
);
}

if let Some(start_prank_call) = &self.start_prank_opts {
let this_address = state.local_state.callstack.current.this_address;
if !INTERNAL_CONTRACT_ADDRESSES.contains(&this_address) {
state.local_state.callstack.current.msg_sender = start_prank_call.sender;
}
}

TracerExecutionStatus::Continue
}
}
Expand All @@ -154,6 +195,7 @@
pub fn new(node_ctx: F) -> Self {
Self {
node_ctx,
start_prank_opts: None,
returndata: None,
return_ptr: None,
near_calls: 0,
Expand Down Expand Up @@ -193,6 +235,21 @@
);
storage.borrow_mut().set_value(code_key, hash);
}
Roll(RollCall { block_number }) => {
tracing::info!("Setting block number to {}", block_number);

let key = StorageKey::new(
AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS),
zksync_types::CURRENT_VIRTUAL_BLOCK_INFO_POSITION,
);
let mut storage = storage.borrow_mut();
let (_, block_timestamp) =
unpack_block_info(h256_to_u256(storage.read_value(&key)));
storage.set_value(
key,
u256_to_h256(pack_block_info(block_number.as_u64(), block_timestamp)),
);
}
GetNonce(GetNonceCall { account }) => {
tracing::info!("Getting nonce for {account:?}");
let mut storage = storage.borrow_mut();
Expand All @@ -205,7 +262,7 @@
account_nonce.as_u64()
);
tracing::info!("👷 Setting returndata",);
self.returndata = Some(vec![account_nonce.into()]);
self.returndata = Some(vec![account_nonce]);
}
SetNonce(SetNonceCall { account, nonce }) => {
tracing::info!("Setting nonce for {account:?} to {nonce}");
Expand Down Expand Up @@ -240,20 +297,42 @@
);
storage.set_value(nonce_key, u256_to_h256(enforced_full_nonce));
}
Roll(RollCall { block_number }) => {
tracing::info!("Setting block number to {}", block_number);
StartPrank(StartPrankCall { sender }) => {
tracing::info!("Starting prank to {sender:?}");
self.start_prank_opts = Some(StartPrankOpts {
sender,
origin: None,
});
}
StartPrankWithOrigin(StartPrankWithOriginCall { sender, origin }) => {
tracing::info!("Starting prank to {sender:?} with origin {origin:?}");

let key = StorageKey::new(
AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS),
zksync_types::CURRENT_VIRTUAL_BLOCK_INFO_POSITION,
zksync_types::SYSTEM_CONTEXT_TX_ORIGIN_POSITION,
);
let mut storage = storage.borrow_mut();
let (_, block_timestamp) =
unpack_block_info(h256_to_u256(storage.read_value(&key)));
storage.set_value(
key,
u256_to_h256(pack_block_info(block_number.as_u64(), block_timestamp)),
);
let original_tx_origin = storage.read_value(&key);
storage.set_value(key, origin.into());

self.start_prank_opts = Some(StartPrankOpts {
sender,
origin: Some(original_tx_origin.into()),
});
}
StopPrank(StopPrankCall) => {
tracing::info!("Stopping prank");

if let Some(origin) = self.start_prank_opts.as_ref().and_then(|v| v.origin) {
let key = StorageKey::new(
AccountTreeId::new(zksync_types::SYSTEM_CONTEXT_ADDRESS),
zksync_types::SYSTEM_CONTEXT_TX_ORIGIN_POSITION,
);
let mut storage = storage.borrow_mut();
storage.set_value(key, origin.into());
}

self.start_prank_opts = None;
}
Warp(WarpCall { timestamp }) => {
tracing::info!("Setting block timestamp {}", timestamp);
Expand Down Expand Up @@ -318,14 +397,14 @@
use crate::{
deps::system_contracts::bytecode_from_slice,
http_fork_source::HttpForkSource,
node::{InMemoryNode, TransactionResult},

Check failure on line 400 in src/cheatcodes.rs

View workflow job for this annotation

GitHub Actions / unit-tests (macos-latest)

unused imports: `LogBuilder`, `TransactionBuilder`, `TransactionResult`
testing::{self, LogBuilder, TransactionBuilder},
};
use ethers::abi::{short_signature, AbiEncode, HumanReadableParser, ParamType, Token};

Check failure on line 403 in src/cheatcodes.rs

View workflow job for this annotation

GitHub Actions / unit-tests (macos-latest)

unused imports: `AbiEncode`, `HumanReadableParser`, `ParamType`, `Token`
use zksync_basic_types::{Address, L2ChainId, Nonce, H160, H256, U256};
use zksync_core::api_server::web3::backend_jsonrpc::namespaces::eth::EthNamespaceT;
use zksync_types::{
api::{Block, CallTracerConfig, SupportedTracers, TransactionReceipt},

Check failure on line 407 in src/cheatcodes.rs

View workflow job for this annotation

GitHub Actions / unit-tests (macos-latest)

unused imports: `Block`, `CallTracerConfig`, `SupportedTracers`, `TransactionReceipt`, `transaction_request::CallRequestBuilder`
fee::Fee,
l2::L2Tx,
transaction_request::CallRequestBuilder,
Expand Down
Loading