Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Refactor assumeNot* std cheats #407

Merged
merged 18 commits into from
Jul 11, 2023
Merged
109 changes: 99 additions & 10 deletions src/StdCheats.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import {Vm} from "./Vm.sol";
abstract contract StdCheatsSafe {
Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));

uint256 private constant UINT256_MAX =
115792089237316195423570985008687907853269984665640564039457584007913129639935;

bool private gasMeteringOff;

// Data structures to parse Transaction objects from the broadcast artifact
Expand Down Expand Up @@ -193,6 +196,14 @@ abstract contract StdCheatsSafe {
uint256 key;
}

enum AddressType {
mds1 marked this conversation as resolved.
Show resolved Hide resolved
Payable,
NonPayable,
ZeroAddress,
Precompile,
ForgeAddress
}

// Checks that `addr` is not blacklisted by token contracts that have a blacklist.
function assumeNotBlacklisted(address token, address addr) internal view virtual {
// Nothing to check if `token` is not a contract.
Expand Down Expand Up @@ -222,11 +233,91 @@ abstract contract StdCheatsSafe {
assumeNotBlacklisted(token, addr);
}

function assumeNoPrecompiles(address addr) internal pure virtual {
assumeNoPrecompiles(addr, _pureChainId());
function assumeAddressIsNot(address addr, AddressType addressType) internal virtual {
if (addressType == AddressType.Payable) {
assumeNotPayable(addr);
} else if (addressType == AddressType.NonPayable) {
assumePayable(addr);
} else if (addressType == AddressType.ZeroAddress) {
assumeNotZeroAddress(addr);
} else if (addressType == AddressType.Precompile) {
assumeNotPrecompile(addr);
} else if (addressType == AddressType.ForgeAddress) {
assumeNotForgeAddress(addr);
}
}

function assumeAddressIsNot(address addr, AddressType addressType1, AddressType addressType2) internal virtual {
assumeAddressIsNot(addr, addressType1);
assumeAddressIsNot(addr, addressType2);
}

function assumeAddressIsNot(
address addr,
AddressType addressType1,
AddressType addressType2,
AddressType addressType3
) internal virtual {
assumeAddressIsNot(addr, addressType1);
assumeAddressIsNot(addr, addressType2);
assumeAddressIsNot(addr, addressType3);
}

function assumeAddressIsNot(
address addr,
AddressType addressType1,
AddressType addressType2,
AddressType addressType3,
AddressType addressType4
) internal virtual {
assumeAddressIsNot(addr, addressType1);
assumeAddressIsNot(addr, addressType2);
assumeAddressIsNot(addr, addressType3);
assumeAddressIsNot(addr, addressType4);
}

// This function checks whether an address, `addr`, is payable. It works by sending 1 wei to
// `addr` and checking the `success` return value.
// NOTE: This function may result in state changes depending on the fallback/receive logic
// implemented by `addr`, which should be taken into account when this function is used.
function _isPayable(address addr) private returns (bool) {
require(
addr.balance < UINT256_MAX,
"StdCheats _isPayable(address): Balance equals max uint256, so it cannot receive any more funds"
);
uint256 origBalanceTest = address(this).balance;
uint256 origBalanceAddr = address(addr).balance;

vm.deal(address(this), 1);
(bool success,) = payable(addr).call{value: 1}("");

// reset balances
vm.deal(address(this), origBalanceTest);
vm.deal(addr, origBalanceAddr);

return success;
}

// NOTE: This function may result in state changes depending on the fallback/receive logic
// implemented by `addr`, which should be taken into account when this function is used. See the
// `_isPayable` method for more information.
function assumePayable(address addr) internal virtual {
vm.assume(_isPayable(addr));
}

function assumeNotPayable(address addr) internal virtual {
vm.assume(!_isPayable(addr));
}

function assumeNoPrecompiles(address addr, uint256 chainId) internal pure virtual {
function assumeNotZeroAddress(address addr) internal pure virtual {
vm.assume(addr != address(0));
}

function assumeNotPrecompile(address addr) internal pure virtual {
assumeNotPrecompile(addr, _pureChainId());
}

function assumeNotPrecompile(address addr, uint256 chainId) internal pure virtual {
// Note: For some chains like Optimism these are technically predeploys (i.e. bytecode placed at a specific
// address), but the same rationale for excluding them applies so we include those too.

Expand All @@ -249,6 +340,11 @@ abstract contract StdCheatsSafe {
// forgefmt: disable-end
}

function assumeNotForgeAddress(address addr) internal pure virtual {
// vm and console addresses
vm.assume(addr != address(vm) || addr != 0x000000000000000000636F6e736F6c652e6c6f67);
}

function readEIP1559ScriptArtifact(string memory path)
internal
view
Expand Down Expand Up @@ -512,13 +608,6 @@ abstract contract StdCheatsSafe {
}
}

// a cheat for fuzzing addresses that are payable only
// see https://github.com/foundry-rs/foundry/issues/3631
function assumePayable(address addr) internal virtual {
(bool success,) = payable(addr).call{value: 0}("");
vm.assume(success);
}

// We use this complex approach of `_viewChainId` and `_pureChainId` to ensure there are no
// compiler warnings when accessing chain ID in any solidity version supported by forge-std. We
// can't simply access the chain ID in a normal view or pure function because the solc View Pure
Expand Down
71 changes: 57 additions & 14 deletions test/StdCheats.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -335,21 +335,20 @@ contract StdCheatsTest is Test {
return number;
}

function testAssumeNoPrecompiles(address addr) external {
assumeNoPrecompiles(addr, getChain("optimism_goerli").chainId);
assertTrue(
addr < address(1) || (addr > address(9) && addr < address(0x4200000000000000000000000000000000000000))
|| addr > address(0x4200000000000000000000000000000000000800)
);
}

function _assumePayable(address addr) public {
assumePayable(addr);
function testAssumeAddressIsNot(address addr) external {
// skip over Payable and NonPayable enums
for (uint8 i = 2; i < uint8(type(AddressType).max); i++) {
assumeAddressIsNot(addr, AddressType(i));
}
assertTrue(addr != address(0));
assertTrue(addr < address(1) || addr > address(9));
assertTrue(addr != address(vm) || addr != 0x000000000000000000636F6e736F6c652e6c6f67);
}

function testAssumePayable() external {
// We deploy a mock version so we can properly test the revert.
StdCheatsMock stdCheatsMock = new StdCheatsMock();

// all should revert since these addresses are not payable

// VM address
Expand All @@ -363,13 +362,49 @@ contract StdCheatsTest is Test {
// Create2Deployer
vm.expectRevert();
stdCheatsMock.exposed_assumePayable(0x4e59b44847b379578588920cA78FbF26c0B4956C);

// all should pass since these addresses are payable

// vitalik.eth
stdCheatsMock.exposed_assumePayable(0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045);

// mock payable contract
MockContractPayable cp = new MockContractPayable();
stdCheatsMock.exposed_assumePayable(address(cp));
}

function testAssumePayable(address addr) external {
assumePayable(addr);
function testAssumeNotPayable() external {
// We deploy a mock version so we can properly test the revert.
StdCheatsMock stdCheatsMock = new StdCheatsMock();

// all should pass since these addresses are not payable

// VM address
stdCheatsMock.exposed_assumeNotPayable(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D);

// Console address
stdCheatsMock.exposed_assumeNotPayable(0x000000000000000000636F6e736F6c652e6c6f67);

// Create2Deployer
stdCheatsMock.exposed_assumeNotPayable(0x4e59b44847b379578588920cA78FbF26c0B4956C);

// all should revert since these addresses are payable

// vitalik.eth
vm.expectRevert();
stdCheatsMock.exposed_assumeNotPayable(0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045);

// mock payable contract
MockContractPayable cp = new MockContractPayable();
vm.expectRevert();
stdCheatsMock.exposed_assumeNotPayable(address(cp));
}

function testAssumeNotPrecompile(address addr) external {
assumeNotPrecompile(addr, getChain("optimism_goerli").chainId);
assertTrue(
addr != 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D && addr != 0x000000000000000000636F6e736F6c652e6c6f67
&& addr != 0x4e59b44847b379578588920cA78FbF26c0B4956C
addr < address(1) || (addr > address(9) && addr < address(0x4200000000000000000000000000000000000000))
|| addr > address(0x4200000000000000000000000000000000000800)
);
}

Expand Down Expand Up @@ -406,6 +441,10 @@ contract StdCheatsMock is StdCheats {
assumePayable(addr);
}

function exposed_assumeNotPayable(address addr) external {
assumeNotPayable(addr);
}

// We deploy a mock version so we can properly test expected reverts.
function exposed_assumeNotBlacklisted(address token, address addr) external view {
return assumeNotBlacklisted(token, addr);
Expand Down Expand Up @@ -557,3 +596,7 @@ contract MockContractWithConstructorArgs {
z = _z;
}
}

contract MockContractPayable {
receive() external payable {}
}