Skip to content

Commit

Permalink
feat: only allow fee controller to collect protocolFee (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChefMist committed Sep 18, 2024
1 parent 6358370 commit 7b15a99
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .forge-snapshots/BinPoolManagerBytecodeSize.snap
Original file line number Diff line number Diff line change
@@ -1 +1 @@
24457
24421
2 changes: 1 addition & 1 deletion .forge-snapshots/CLPoolManagerBytecodeSize.snap
Original file line number Diff line number Diff line change
@@ -1 +1 @@
21334
21307
3 changes: 1 addition & 2 deletions src/ProtocolFees.sol
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ abstract contract ProtocolFees is IProtocolFees, Owner {
override
returns (uint256 amountCollected)
{
// todo: remove msg.sender access to collectProtocolFees
if (msg.sender != owner() && msg.sender != address(protocolFeeController)) revert InvalidCaller();
if (msg.sender != address(protocolFeeController)) revert InvalidCaller();

amountCollected = (amount == 0) ? protocolFeesAccrued[currency] : amount;
protocolFeesAccrued[currency] -= amountCollected;
Expand Down
12 changes: 9 additions & 3 deletions test/ProtocolFees.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,17 @@ contract ProtocolFeesTest is Test {
assertEq(protocolFee1, 1e15);
}

function test_CollectProtocolFee_OnlyOwnerOrFeeController() public {
function test_CollectProtocolFee_OnlyFeeController() public {
// random user
vm.expectRevert(IProtocolFees.InvalidCaller.selector);

vm.prank(address(alice));
poolManager.collectProtocolFees(alice, Currency.wrap(address(token0)), 1e18);

// owner
address pmOwner = poolManager.owner();
vm.expectRevert(IProtocolFees.InvalidCaller.selector);
vm.prank(pmOwner);
poolManager.collectProtocolFees(alice, Currency.wrap(address(token0)), 1e18);
}

function test_CollectProtocolFee() public {
Expand All @@ -212,7 +218,7 @@ contract ProtocolFeesTest is Test {
assertEq(token1.balanceOf(address(vault)), 1e15);

// collect
vm.prank(address(feeController));
vm.startPrank(address(feeController));
poolManager.collectProtocolFees(alice, Currency.wrap(address(token0)), 1e15);
poolManager.collectProtocolFees(alice, Currency.wrap(address(token1)), 1e15);

Expand Down
45 changes: 4 additions & 41 deletions test/pool-cl/CLPoolManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2688,46 +2688,6 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps
assertEq(slot0.protocolFee, protocolFee);
}

function testCollectProtocolFees_ERC20_allowsOwnerToAccumulateFees() public {
// protocol fee 0.1%
uint24 protocolFee = ProtocolFeeLibrary.MAX_PROTOCOL_FEE | (uint24(ProtocolFeeLibrary.MAX_PROTOCOL_FEE) << 12);
uint256 expectedProtocolFees =
uint256(10000) * ProtocolFeeLibrary.MAX_PROTOCOL_FEE / ProtocolFeeLibrary.PIPS_DENOMINATOR;

PoolKey memory key = PoolKey({
currency0: currency0,
currency1: currency1,
// 0.3% lp fee
fee: 3000,
hooks: IHooks(address(0)),
poolManager: poolManager,
parameters: bytes32(uint256(10) << 16)
});
poolManager.setProtocolFeeController(IProtocolFeeController(address(feeController)));
feeController.setProtocolFeeForPool(key.toId(), protocolFee);

poolManager.initialize(key, SQRT_RATIO_1_1, ZERO_BYTES);
(CLPool.Slot0 memory slot0,,,) = poolManager.pools(key.toId());
assertEq(slot0.protocolFee, protocolFee);

ICLPoolManager.ModifyLiquidityParams memory params =
ICLPoolManager.ModifyLiquidityParams(-120, 120, 10 ether, 0);
router.modifyPosition(key, params, ZERO_BYTES);
router.swap(
key,
ICLPoolManager.SwapParams(true, 10000, SQRT_RATIO_1_2),
CLPoolManagerRouter.SwapTestSettings(true, true),
ZERO_BYTES
);

assertEq(poolManager.protocolFeesAccrued(currency0), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(currency1), 0);
assertEq(currency0.balanceOf(address(1)), 0);
poolManager.collectProtocolFees(address(1), currency0, expectedProtocolFees);
assertEq(currency0.balanceOf(address(1)), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(currency0), 0);
}

function testCollectProtocolFees_ERC20_returnsAllFeesIf0IsProvidedAsParameter() public {
// protocol fee 0.1%
uint24 protocolFee = ProtocolFeeLibrary.MAX_PROTOCOL_FEE | (uint24(ProtocolFeeLibrary.MAX_PROTOCOL_FEE) << 12);
Expand Down Expand Up @@ -2762,12 +2722,13 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps
assertEq(poolManager.protocolFeesAccrued(currency0), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(currency1), 0);
assertEq(currency0.balanceOf(address(1)), 0);
vm.prank(address(feeController));
poolManager.collectProtocolFees(address(1), currency0, 0);
assertEq(currency0.balanceOf(address(1)), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(currency0), 0);
}

function testCollectProtocolFees_nativeToken_allowsOwnerToAccumulateFees() public {
function testCollectProtocolFees_nativeToken_allowsFeeControllerToAccumulateFees() public {
// protocol fee 0.1%
uint24 protocolFee = ProtocolFeeLibrary.MAX_PROTOCOL_FEE | (uint24(ProtocolFeeLibrary.MAX_PROTOCOL_FEE) << 12);
uint256 expectedProtocolFees =
Expand Down Expand Up @@ -2802,6 +2763,7 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps
assertEq(poolManager.protocolFeesAccrued(nativeCurrency), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(currency1), 0);
assertEq(nativeCurrency.balanceOf(address(1)), 0);
vm.prank(address(feeController));
poolManager.collectProtocolFees(address(1), nativeCurrency, expectedProtocolFees);
assertEq(nativeCurrency.balanceOf(address(1)), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(nativeCurrency), 0);
Expand Down Expand Up @@ -2842,6 +2804,7 @@ contract CLPoolManagerTest is Test, NoIsolate, Deployers, TokenFixture, GasSnaps
assertEq(poolManager.protocolFeesAccrued(nativeCurrency), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(currency1), 0);
assertEq(nativeCurrency.balanceOf(address(1)), 0);
vm.prank(address(feeController));
poolManager.collectProtocolFees(address(1), nativeCurrency, 0);
assertEq(nativeCurrency.balanceOf(address(1)), expectedProtocolFees);
assertEq(poolManager.protocolFeesAccrued(nativeCurrency), 0);
Expand Down

0 comments on commit 7b15a99

Please sign in to comment.