diff --git a/src/StdCheats.sol b/src/StdCheats.sol index c238358e..1b01726b 100644 --- a/src/StdCheats.sol +++ b/src/StdCheats.sol @@ -9,6 +9,9 @@ import {Vm} from "./Vm.sol"; abstract contract StdCheatsSafe { Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code"))))); + uint256 private constant UINT256_MAX = + 115792089237316195423570985008687907853269984665640564039457584007913129639935; + bool private gasMeteringOff; // Data structures to parse Transaction objects from the broadcast artifact @@ -193,6 +196,14 @@ abstract contract StdCheatsSafe { uint256 key; } + enum AddressType { + Payable, + NonPayable, + ZeroAddress, + Precompile, + ForgeAddress + } + // Checks that `addr` is not blacklisted by token contracts that have a blacklist. function assumeNotBlacklisted(address token, address addr) internal view virtual { // Nothing to check if `token` is not a contract. @@ -222,11 +233,91 @@ abstract contract StdCheatsSafe { assumeNotBlacklisted(token, addr); } - function assumeNoPrecompiles(address addr) internal pure virtual { - assumeNoPrecompiles(addr, _pureChainId()); + function assumeAddressIsNot(address addr, AddressType addressType) internal virtual { + if (addressType == AddressType.Payable) { + assumeNotPayable(addr); + } else if (addressType == AddressType.NonPayable) { + assumePayable(addr); + } else if (addressType == AddressType.ZeroAddress) { + assumeNotZeroAddress(addr); + } else if (addressType == AddressType.Precompile) { + assumeNotPrecompile(addr); + } else if (addressType == AddressType.ForgeAddress) { + assumeNotForgeAddress(addr); + } + } + + function assumeAddressIsNot(address addr, AddressType addressType1, AddressType addressType2) internal virtual { + assumeAddressIsNot(addr, addressType1); + assumeAddressIsNot(addr, addressType2); + } + + function assumeAddressIsNot( + address addr, + AddressType addressType1, + AddressType addressType2, + AddressType addressType3 + ) internal virtual { + assumeAddressIsNot(addr, addressType1); + assumeAddressIsNot(addr, addressType2); + assumeAddressIsNot(addr, addressType3); + } + + function assumeAddressIsNot( + address addr, + AddressType addressType1, + AddressType addressType2, + AddressType addressType3, + AddressType addressType4 + ) internal virtual { + assumeAddressIsNot(addr, addressType1); + assumeAddressIsNot(addr, addressType2); + assumeAddressIsNot(addr, addressType3); + assumeAddressIsNot(addr, addressType4); + } + + // This function checks whether an address, `addr`, is payable. It works by sending 1 wei to + // `addr` and checking the `success` return value. + // NOTE: This function may result in state changes depending on the fallback/receive logic + // implemented by `addr`, which should be taken into account when this function is used. + function _isPayable(address addr) private returns (bool) { + require( + addr.balance < UINT256_MAX, + "StdCheats _isPayable(address): Balance equals max uint256, so it cannot receive any more funds" + ); + uint256 origBalanceTest = address(this).balance; + uint256 origBalanceAddr = address(addr).balance; + + vm.deal(address(this), 1); + (bool success,) = payable(addr).call{value: 1}(""); + + // reset balances + vm.deal(address(this), origBalanceTest); + vm.deal(addr, origBalanceAddr); + + return success; + } + + // NOTE: This function may result in state changes depending on the fallback/receive logic + // implemented by `addr`, which should be taken into account when this function is used. See the + // `_isPayable` method for more information. + function assumePayable(address addr) internal virtual { + vm.assume(_isPayable(addr)); + } + + function assumeNotPayable(address addr) internal virtual { + vm.assume(!_isPayable(addr)); } - function assumeNoPrecompiles(address addr, uint256 chainId) internal pure virtual { + function assumeNotZeroAddress(address addr) internal pure virtual { + vm.assume(addr != address(0)); + } + + function assumeNotPrecompile(address addr) internal pure virtual { + assumeNotPrecompile(addr, _pureChainId()); + } + + function assumeNotPrecompile(address addr, uint256 chainId) internal pure virtual { // Note: For some chains like Optimism these are technically predeploys (i.e. bytecode placed at a specific // address), but the same rationale for excluding them applies so we include those too. @@ -249,6 +340,11 @@ abstract contract StdCheatsSafe { // forgefmt: disable-end } + function assumeNotForgeAddress(address addr) internal pure virtual { + // vm and console addresses + vm.assume(addr != address(vm) || addr != 0x000000000000000000636F6e736F6c652e6c6f67); + } + function readEIP1559ScriptArtifact(string memory path) internal view @@ -512,13 +608,6 @@ abstract contract StdCheatsSafe { } } - // a cheat for fuzzing addresses that are payable only - // see https://github.com/foundry-rs/foundry/issues/3631 - function assumePayable(address addr) internal virtual { - (bool success,) = payable(addr).call{value: 0}(""); - vm.assume(success); - } - // We use this complex approach of `_viewChainId` and `_pureChainId` to ensure there are no // compiler warnings when accessing chain ID in any solidity version supported by forge-std. We // can't simply access the chain ID in a normal view or pure function because the solc View Pure diff --git a/test/StdCheats.t.sol b/test/StdCheats.t.sol index 22735811..05a56882 100644 --- a/test/StdCheats.t.sol +++ b/test/StdCheats.t.sol @@ -335,21 +335,20 @@ contract StdCheatsTest is Test { return number; } - function testAssumeNoPrecompiles(address addr) external { - assumeNoPrecompiles(addr, getChain("optimism_goerli").chainId); - assertTrue( - addr < address(1) || (addr > address(9) && addr < address(0x4200000000000000000000000000000000000000)) - || addr > address(0x4200000000000000000000000000000000000800) - ); - } - - function _assumePayable(address addr) public { - assumePayable(addr); + function testAssumeAddressIsNot(address addr) external { + // skip over Payable and NonPayable enums + for (uint8 i = 2; i < uint8(type(AddressType).max); i++) { + assumeAddressIsNot(addr, AddressType(i)); + } + assertTrue(addr != address(0)); + assertTrue(addr < address(1) || addr > address(9)); + assertTrue(addr != address(vm) || addr != 0x000000000000000000636F6e736F6c652e6c6f67); } function testAssumePayable() external { // We deploy a mock version so we can properly test the revert. StdCheatsMock stdCheatsMock = new StdCheatsMock(); + // all should revert since these addresses are not payable // VM address @@ -363,13 +362,49 @@ contract StdCheatsTest is Test { // Create2Deployer vm.expectRevert(); stdCheatsMock.exposed_assumePayable(0x4e59b44847b379578588920cA78FbF26c0B4956C); + + // all should pass since these addresses are payable + + // vitalik.eth + stdCheatsMock.exposed_assumePayable(0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045); + + // mock payable contract + MockContractPayable cp = new MockContractPayable(); + stdCheatsMock.exposed_assumePayable(address(cp)); } - function testAssumePayable(address addr) external { - assumePayable(addr); + function testAssumeNotPayable() external { + // We deploy a mock version so we can properly test the revert. + StdCheatsMock stdCheatsMock = new StdCheatsMock(); + + // all should pass since these addresses are not payable + + // VM address + stdCheatsMock.exposed_assumeNotPayable(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D); + + // Console address + stdCheatsMock.exposed_assumeNotPayable(0x000000000000000000636F6e736F6c652e6c6f67); + + // Create2Deployer + stdCheatsMock.exposed_assumeNotPayable(0x4e59b44847b379578588920cA78FbF26c0B4956C); + + // all should revert since these addresses are payable + + // vitalik.eth + vm.expectRevert(); + stdCheatsMock.exposed_assumeNotPayable(0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045); + + // mock payable contract + MockContractPayable cp = new MockContractPayable(); + vm.expectRevert(); + stdCheatsMock.exposed_assumeNotPayable(address(cp)); + } + + function testAssumeNotPrecompile(address addr) external { + assumeNotPrecompile(addr, getChain("optimism_goerli").chainId); assertTrue( - addr != 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D && addr != 0x000000000000000000636F6e736F6c652e6c6f67 - && addr != 0x4e59b44847b379578588920cA78FbF26c0B4956C + addr < address(1) || (addr > address(9) && addr < address(0x4200000000000000000000000000000000000000)) + || addr > address(0x4200000000000000000000000000000000000800) ); } @@ -406,6 +441,10 @@ contract StdCheatsMock is StdCheats { assumePayable(addr); } + function exposed_assumeNotPayable(address addr) external { + assumeNotPayable(addr); + } + // We deploy a mock version so we can properly test expected reverts. function exposed_assumeNotBlacklisted(address token, address addr) external view { return assumeNotBlacklisted(token, addr); @@ -557,3 +596,7 @@ contract MockContractWithConstructorArgs { z = _z; } } + +contract MockContractPayable { + receive() external payable {} +}