diff --git a/.forge-snapshots/BinPoolManagerBytecodeSize.snap b/.forge-snapshots/BinPoolManagerBytecodeSize.snap index 20354dc..0444aa9 100644 --- a/.forge-snapshots/BinPoolManagerBytecodeSize.snap +++ b/.forge-snapshots/BinPoolManagerBytecodeSize.snap @@ -1 +1 @@ -24457 \ No newline at end of file +24421 \ No newline at end of file diff --git a/.forge-snapshots/CLPoolManagerBytecodeSize.snap b/.forge-snapshots/CLPoolManagerBytecodeSize.snap index 7d7366f..4de01e1 100644 --- a/.forge-snapshots/CLPoolManagerBytecodeSize.snap +++ b/.forge-snapshots/CLPoolManagerBytecodeSize.snap @@ -1 +1 @@ -21334 \ No newline at end of file +21307 \ No newline at end of file diff --git a/src/ProtocolFees.sol b/src/ProtocolFees.sol index f173512..8a3d1ef 100644 --- a/src/ProtocolFees.sol +++ b/src/ProtocolFees.sol @@ -88,8 +88,7 @@ abstract contract ProtocolFees is IProtocolFees, Owner { override returns (uint256 amountCollected) { - // todo: remove msg.sender access to collectProtocolFees - if (msg.sender != owner() && msg.sender != address(protocolFeeController)) revert InvalidCaller(); + if (msg.sender != address(protocolFeeController)) revert InvalidCaller(); amountCollected = (amount == 0) ? protocolFeesAccrued[currency] : amount; protocolFeesAccrued[currency] -= amountCollected; diff --git a/test/ProtocolFees.t.sol b/test/ProtocolFees.t.sol index aad9602..e2afc60 100644 --- a/test/ProtocolFees.t.sol +++ b/test/ProtocolFees.t.sol @@ -184,11 +184,17 @@ contract ProtocolFeesTest is Test { assertEq(protocolFee1, 1e15); } - function test_CollectProtocolFee_OnlyOwnerOrFeeController() public { + function test_CollectProtocolFee_OnlyFeeController() public { + // random user vm.expectRevert(IProtocolFees.InvalidCaller.selector); - vm.prank(address(alice)); poolManager.collectProtocolFees(alice, Currency.wrap(address(token0)), 1e18); + + // owner + address pmOwner = poolManager.owner(); + vm.expectRevert(IProtocolFees.InvalidCaller.selector); + vm.prank(pmOwner); + poolManager.collectProtocolFees(alice, Currency.wrap(address(token0)), 1e18); } function test_CollectProtocolFee() public { @@ -212,7 +218,7 @@ contract ProtocolFeesTest is Test { assertEq(token1.balanceOf(address(vault)), 1e15); // collect - vm.prank(address(feeController)); + vm.startPrank(address(feeController)); poolManager.collectProtocolFees(alice, Currency.wrap(address(token0)), 1e15); poolManager.collectProtocolFees(alice, Currency.wrap(address(token1)), 1e15); diff --git a/test/pool-cl/CLPoolManager.t.sol b/test/pool-cl/CLPoolManager.t.sol index adeb2f2..af2ab2c 100644 --- a/test/pool-cl/CLPoolManager.t.sol +++ b/test/pool-cl/CLPoolManager.t.sol @@ -2688,46 +2688,6 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps assertEq(slot0.protocolFee, protocolFee); } - function testCollectProtocolFees_ERC20_allowsOwnerToAccumulateFees() public { - // protocol fee 0.1% - uint24 protocolFee = ProtocolFeeLibrary.MAX_PROTOCOL_FEE | (uint24(ProtocolFeeLibrary.MAX_PROTOCOL_FEE) << 12); - uint256 expectedProtocolFees = - uint256(10000) * ProtocolFeeLibrary.MAX_PROTOCOL_FEE / ProtocolFeeLibrary.PIPS_DENOMINATOR; - - PoolKey memory key = PoolKey({ - currency0: currency0, - currency1: currency1, - // 0.3% lp fee - fee: 3000, - hooks: IHooks(address(0)), - poolManager: poolManager, - parameters: bytes32(uint256(10) << 16) - }); - poolManager.setProtocolFeeController(IProtocolFeeController(address(feeController))); - feeController.setProtocolFeeForPool(key.toId(), protocolFee); - - poolManager.initialize(key, SQRT_RATIO_1_1, ZERO_BYTES); - (CLPool.Slot0 memory slot0,,,) = poolManager.pools(key.toId()); - assertEq(slot0.protocolFee, protocolFee); - - ICLPoolManager.ModifyLiquidityParams memory params = - ICLPoolManager.ModifyLiquidityParams(-120, 120, 10 ether, 0); - router.modifyPosition(key, params, ZERO_BYTES); - router.swap( - key, - ICLPoolManager.SwapParams(true, 10000, SQRT_RATIO_1_2), - CLPoolManagerRouter.SwapTestSettings(true, true), - ZERO_BYTES - ); - - assertEq(poolManager.protocolFeesAccrued(currency0), expectedProtocolFees); - assertEq(poolManager.protocolFeesAccrued(currency1), 0); - assertEq(currency0.balanceOf(address(1)), 0); - poolManager.collectProtocolFees(address(1), currency0, expectedProtocolFees); - assertEq(currency0.balanceOf(address(1)), expectedProtocolFees); - assertEq(poolManager.protocolFeesAccrued(currency0), 0); - } - function testCollectProtocolFees_ERC20_returnsAllFeesIf0IsProvidedAsParameter() public { // protocol fee 0.1% uint24 protocolFee = ProtocolFeeLibrary.MAX_PROTOCOL_FEE | (uint24(ProtocolFeeLibrary.MAX_PROTOCOL_FEE) << 12); @@ -2762,12 +2722,13 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps assertEq(poolManager.protocolFeesAccrued(currency0), expectedProtocolFees); assertEq(poolManager.protocolFeesAccrued(currency1), 0); assertEq(currency0.balanceOf(address(1)), 0); + vm.prank(address(feeController)); poolManager.collectProtocolFees(address(1), currency0, 0); assertEq(currency0.balanceOf(address(1)), expectedProtocolFees); assertEq(poolManager.protocolFeesAccrued(currency0), 0); } - function testCollectProtocolFees_nativeToken_allowsOwnerToAccumulateFees() public { + function testCollectProtocolFees_nativeToken_allowsFeeControllerToAccumulateFees() public { // protocol fee 0.1% uint24 protocolFee = ProtocolFeeLibrary.MAX_PROTOCOL_FEE | (uint24(ProtocolFeeLibrary.MAX_PROTOCOL_FEE) << 12); uint256 expectedProtocolFees = @@ -2802,6 +2763,7 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps assertEq(poolManager.protocolFeesAccrued(nativeCurrency), expectedProtocolFees); assertEq(poolManager.protocolFeesAccrued(currency1), 0); assertEq(nativeCurrency.balanceOf(address(1)), 0); + vm.prank(address(feeController)); poolManager.collectProtocolFees(address(1), nativeCurrency, expectedProtocolFees); assertEq(nativeCurrency.balanceOf(address(1)), expectedProtocolFees); assertEq(poolManager.protocolFeesAccrued(nativeCurrency), 0); @@ -2842,6 +2804,7 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps assertEq(poolManager.protocolFeesAccrued(nativeCurrency), expectedProtocolFees); assertEq(poolManager.protocolFeesAccrued(currency1), 0); assertEq(nativeCurrency.balanceOf(address(1)), 0); + vm.prank(address(feeController)); poolManager.collectProtocolFees(address(1), nativeCurrency, 0); assertEq(nativeCurrency.balanceOf(address(1)), expectedProtocolFees); assertEq(poolManager.protocolFeesAccrued(nativeCurrency), 0);