diff --git a/chain/standard_cheat_code_contract.go b/chain/standard_cheat_code_contract.go index 79c2d3ad..144c7e77 100644 --- a/chain/standard_cheat_code_contract.go +++ b/chain/standard_cheat_code_contract.go @@ -283,6 +283,27 @@ func getStandardCheatCodeContract(tracer *cheatCodeTracer) (*CheatCodeContract, }, ) + // snapshot: Takes a snapshot of the current state of the evm and returns the id associated with the snapshot + contract.addMethod( + "snapshot", abi.Arguments{}, abi.Arguments{{Type: typeUint256}}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + snapshotID := tracer.evm.StateDB.Snapshot() + + return []any{snapshotID}, nil + }, + ) + + // revertTo(uint256): Revert the state of the evm to a previous snapshot. Takes the snapshot id to revert to. + contract.addMethod( + "revertTo", abi.Arguments{{Type: typeUint256}}, abi.Arguments{{Type: typeBool}}, + func(tracer *cheatCodeTracer, inputs []any) ([]any, *cheatCodeRawReturnData) { + snapshotID := inputs[0].(*big.Int) + tracer.evm.StateDB.RevertToSnapshot(int(snapshotID.Int64())) + + return []any{true}, nil + }, + ) + // FFI: Run arbitrary command on base OS contract.addMethod( "ffi", abi.Arguments{{Type: typeStringSlice}}, abi.Arguments{{Type: typeBytes}}, diff --git a/fuzzing/fuzzer_test.go b/fuzzing/fuzzer_test.go index 9167b183..b5ad96cc 100644 --- a/fuzzing/fuzzer_test.go +++ b/fuzzing/fuzzer_test.go @@ -208,6 +208,7 @@ func TestCheatCodes(t *testing.T) { "testdata/contracts/cheat_codes/utils/to_string.sol", "testdata/contracts/cheat_codes/utils/sign.sol", "testdata/contracts/cheat_codes/utils/parse.sol", + "testdata/contracts/cheat_codes/vm/snapshot_and_revert_to.sol", "testdata/contracts/cheat_codes/vm/coinbase.sol", "testdata/contracts/cheat_codes/vm/chain_id.sol", "testdata/contracts/cheat_codes/vm/deal.sol", diff --git a/fuzzing/testdata/contracts/cheat_codes/vm/snapshot_and_revert_to.sol b/fuzzing/testdata/contracts/cheat_codes/vm/snapshot_and_revert_to.sol new file mode 100644 index 00000000..577ff194 --- /dev/null +++ b/fuzzing/testdata/contracts/cheat_codes/vm/snapshot_and_revert_to.sol @@ -0,0 +1,58 @@ +// This test ensures that we can take a snapshot of the current state of the testchain and revert to the state at that snapshot using the snapshot and revertTo cheatcodes +pragma solidity ^0.8.0; + +interface CheatCodes { + function warp(uint256) external; + + function deal(address, uint256) external; + + function snapshot() external returns (uint256); + + function revertTo(uint256) external returns (bool); +} + +struct Storage { + uint slot0; + uint slot1; +} + +contract TestContract { + Storage store; + uint256 timestamp; + + function test() public { + // Obtain our cheat code contract reference. + CheatCodes cheats = CheatCodes( + 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D + ); + + store.slot0 = 10; + store.slot1 = 20; + timestamp = block.timestamp; + cheats.deal(address(this), 5 ether); + + // Save state + uint256 snapshot = cheats.snapshot(); + + // Change state + store.slot0 = 300; + store.slot1 = 400; + cheats.deal(address(this), 500 ether); + cheats.warp(12345); + + // Assert that state has been changed + assert(store.slot0 == 300); + assert(store.slot1 == 400); + assert(address(this).balance == 500 ether); + assert(block.timestamp == 12345); + + // Revert to snapshot + cheats.revertTo(snapshot); + + // Ensure state has been reset + assert(store.slot0 == 10); + assert(store.slot1 == 20); + assert(address(this).balance == 5 ether); + assert(block.timestamp == timestamp); + } +}