diff --git a/packages/foundry/contracts/hooks/BalancerFun.sol b/packages/foundry/contracts/hooks/BalancerFun.sol new file mode 100644 index 00000000..dfd0b916 --- /dev/null +++ b/packages/foundry/contracts/hooks/BalancerFun.sol @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { IHooks } from "@balancer-labs/v3-interfaces/contracts/vault/IHooks.sol"; +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; +import { + AfterSwapParams, + LiquidityManagement, + TokenConfig, + PoolSwapParams, + RemoveLiquidityKind, + HookFlags, + SwapKind +} from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; +import { VaultGuard } from "@balancer-labs/v3-vault/contracts/VaultGuard.sol"; +import { BaseHooks } from "@balancer-labs/v3-vault/contracts/BaseHooks.sol"; + +/** + * @title BalancerFun Hook Contract + * @notice Implements custom hooks for Balancer V3 liquidity pools to enforce swap limits and prevent liquidity removal. + */ +contract BalancerFun is BaseHooks, VaultGuard { + /// @notice The token that is managed by this hook. + IERC20 public immutable token; + + /// @notice The maximum amount that can be swapped in a single block. + uint public immutable maxSwapAmount; + + /// @notice A mapping that tracks the total amount of tokens sold per block. + mapping(uint => uint) public blockToTotalSold; + + /// @notice Event emitted when the BalancerFun hook is registered. + /// @param hooksContract The address of the hooks contract. + /// @param pool The address of the pool where the hook is registered. + event BalancerFunHookRegistered(address indexed hooksContract, address indexed pool); + + /// @notice Error thrown when a swap exceeds the maximum allowed swap amount. + error MaximumSwapExceeded(); + + /// @notice Error thrown when an attempt is made to remove liquidity, which is not allowed. + error LiquidityIsLocked(); + + /** + * @notice Constructor for the BalancerFun contract. + * @param vault The Balancer Vault contract. + * @param _token The ERC20 token that is managed by this hook. + */ + constructor(IVault vault, IERC20 _token) VaultGuard(vault) { + token = _token; + maxSwapAmount = 1_000_000 ether * 3 / 100; // 3% of total supply, could use token.totalSupply() + } + + /** + * @notice Returns the hook flags indicating which hook functions should be called. + * @return hookFlags The HookFlags struct indicating which hook functions are enabled. + */ + function getHookFlags() public pure override returns (HookFlags memory) { + HookFlags memory hookFlags; + hookFlags.shouldCallAfterSwap = true; + hookFlags.shouldCallAfterRemoveLiquidity = true; + return hookFlags; + } + + /** + * @notice Called when the hook is registered to a pool. + * @param pool The address of the pool to which the hook is being registered. + * @return success Boolean indicating whether the registration was successful. + */ + function onRegister( + address, + address pool, + TokenConfig[] memory, + LiquidityManagement calldata + ) public override onlyVault returns (bool) { + emit BalancerFunHookRegistered(address(this), pool); + return true; + } + + /** + * @notice Called after a swap is performed in the pool. + * @param params The parameters for the swap, including token addresses and amounts. + * @return success Boolean indicating if the swap was successful. + * @return amountCalculatedRaw The calculated amount after the swap. + */ + function onAfterSwap( + AfterSwapParams calldata params + ) public override onlyVault returns (bool, uint) { + if (address(params.tokenIn) == address(token)) { + uint currentBlockSold = blockToTotalSold[block.number]; + if (currentBlockSold + params.amountInScaled18 > maxSwapAmount * 1e18) { + revert MaximumSwapExceeded(); + } + blockToTotalSold[block.number] = currentBlockSold + params.amountInScaled18; + } + return (true, params.amountCalculatedRaw); + } + + /** + * @notice Called after an attempt to remove liquidity from the pool. + * @dev This function always reverts with `LiquidityIsLocked()` error to prevent liquidity removal. + * @return success Boolean indicating if the function succeeded. + * @return emptyArray An empty array to satisfy return requirements. + */ + function onAfterRemoveLiquidity( + address, + address, + RemoveLiquidityKind, + uint, + uint[] memory, + uint[] memory, + uint[] memory, + bytes memory + ) public view override onlyVault returns (bool, uint[] memory) { + revert LiquidityIsLocked(); + } +} \ No newline at end of file diff --git a/packages/foundry/test/BalancerFun.t.sol b/packages/foundry/test/BalancerFun.t.sol new file mode 100644 index 00000000..b04168b3 --- /dev/null +++ b/packages/foundry/test/BalancerFun.t.sol @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import "forge-std/Test.sol"; + +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; + +import { IRouter } from "@balancer-labs/v3-interfaces/contracts/vault/IRouter.sol"; +import { IVault } from "@balancer-labs/v3-interfaces/contracts/vault/IVault.sol"; +import { + LiquidityManagement, + PoolRoleAccounts, + SwapKind +} from "@balancer-labs/v3-interfaces/contracts/vault/VaultTypes.sol"; + +import { CastingHelpers } from "@balancer-labs/v3-solidity-utils/contracts/helpers/CastingHelpers.sol"; +import { FixedPoint } from "@balancer-labs/v3-solidity-utils/contracts/math/FixedPoint.sol"; + +import { BaseVaultTest } from "@balancer-labs/v3-vault/test/foundry/utils/BaseVaultTest.sol"; +import { PoolMock } from "@balancer-labs/v3-vault/contracts/test/PoolMock.sol"; + +import { BalancerFun } from "../contracts/hooks/BalancerFun.sol"; + +/** + * @title BalancerFunTest + * @notice Unit tests for the BalancerFun contract. + * @dev Inherits from BaseVaultTest to perform setup and test BalancerFun interactions. + */ +contract BalancerFunTest is BaseVaultTest { + using CastingHelpers for address[]; + using FixedPoint for uint; + + uint internal daiIdx; + uint internal usdcIdx; + + /** + * @notice Sets up the test environment. + * @dev Overrides BaseVaultTest's setUp function to initialize token indexes. + */ + function setUp() public virtual override { + BaseVaultTest.setUp(); + (daiIdx, usdcIdx) = getSortedIndexes(address(dai), address(usdc)); + } + + /** + * @notice Creates a new BalancerFun hook for testing. + * @dev Deploys a new instance of BalancerFun and sets it as the hook for the pool. + * @return address The address of the newly created hook. + */ + function createHook() internal override returns (address) { + // lp will be the owner of the hook. Only the owner can set hook fee percentages. + vm.prank(lp); + BalancerFun hook = new BalancerFun(IVault(address(vault)), IERC20(address(router))); + return address(hook); + } + + /** + * @notice Creates a new pool with custom liquidity management settings. + * @dev Overrides the pool creation to disable unbalanced liquidity by setting liquidityManagement. + * @param tokens The tokens to be used in the pool. + * @param label A label for the pool. + * @return address The address of the newly created pool. + */ + function _createPool(address[] memory tokens, string memory label) internal override returns (address) { + PoolMock newPool = new PoolMock(IVault(address(vault)), "Balancer.Fun Pool", "BALFUN"); + vm.label(address(newPool), label); + PoolRoleAccounts memory roleAccounts; + roleAccounts.poolCreator = lp; + LiquidityManagement memory liquidityManagement; + factoryMock.registerPool( + address(newPool), + vault.buildTokenConfig(tokens.asIERC20()), + roleAccounts, + poolHooksContract, + liquidityManagement + ); + + return address(newPool); + } + + /** + * @notice Tests the setup of the contract. + * @dev Verifies that the setup has been completed successfully. + */ + function testSetUp() public { + assertEq(daiIdx, 0, "SetUp has failed"); + } + + /** + * @notice Tests a swap operation. + * @dev Executes a swap and verifies the balance changes for Alice. + */ + function testSwap() public { + ( + BaseVaultTest.Balances memory balancesBefore, + BaseVaultTest.Balances memory balancesAfter, + uint swapAmount, + uint[] memory accruedFees, + uint iterations + ) = _executeSwap(1 ether); + + assertEq( + balancesBefore.aliceTokens[daiIdx] - balancesAfter.aliceTokens[daiIdx], + swapAmount, + "Alice DAI balance is wrong" + ); + } + + /** + * @notice Executes a swap operation. + * @dev Performs a swap of the specified amount and returns relevant data. + * @param _swapAmount The amount to be swapped. + * @return balancesBefore The balances before the swap. + * @return balancesAfter The balances after the swap. + * @return swapAmount The amount that was swapped. + * @return accruedFees The accrued fees during the swap. + * @return iterations The number of iterations performed. + */ + function _executeSwap(uint _swapAmount) private returns ( + BaseVaultTest.Balances memory balancesBefore, + BaseVaultTest.Balances memory balancesAfter, + uint swapAmount, + uint[] memory accruedFees, + uint iterations + ) + { + vm.prank(lp); + balancesBefore = getBalances(alice); + bytes4 routerMethod; + routerMethod = IRouter.swapSingleTokenExactIn.selector; + uint amountGiven = _swapAmount; + + vm.prank(alice); + (bool success, ) = address(router).call( + abi.encodeWithSelector( + routerMethod, + address(pool), + dai, + usdc, + amountGiven, + amountGiven, + MAX_UINT256, + false, + bytes("") + ) + ); + + assertTrue(success, "Swap has failed"); + balancesAfter = getBalances(alice); + swapAmount = _swapAmount; + } +}