From e527509b8967f3d88b459c9291527318e9de0bf7 Mon Sep 17 00:00:00 2001 From: Joey Santoro Date: Sat, 19 Mar 2022 18:13:35 -0700 Subject: [PATCH] make booster mutable --- src/FlywheelCore.sol | 26 ++++++++++++++------------ src/test/FlywheelTest.sol | 14 +++++++++++++- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/FlywheelCore.sol b/src/FlywheelCore.sol index 97737d3..c10e5ff 100644 --- a/src/FlywheelCore.sol +++ b/src/FlywheelCore.sol @@ -24,7 +24,9 @@ contract FlywheelCore is Auth { event AddStrategy(address indexed newStrategy); - event FlywheelRewardsUpdate(address indexed oldFlywheelRewards, address indexed newFlywheelRewards); + event FlywheelRewardsUpdate(address indexed newFlywheelRewards); + + event FlywheelBoosterUpdate(address indexed newBooster); event AccrueRewards(ERC20 indexed cToken, address indexed owner, uint rewardsDelta, uint rewardsIndex); @@ -45,7 +47,7 @@ contract FlywheelCore is Auth { IFlywheelRewards public flywheelRewards; /// @notice optional booster module for calculating virtual balances on strategies - IFlywheelBooster public immutable flywheelBooster; + IFlywheelBooster public flywheelBooster; /// @notice the fixed point factor of flywheel uint224 public constant ONE = 1e18; @@ -59,9 +61,6 @@ contract FlywheelCore is Auth { /// @notice The accrued but not yet transferred rewards for each user mapping(address => uint256) public rewardsAccrued; - /// @dev immutable flag for short-circuiting boosting logic - bool internal immutable applyBoosting; - constructor( ERC20 _rewardToken, IFlywheelRewards _flywheelRewards, @@ -72,8 +71,6 @@ contract FlywheelCore is Auth { rewardToken = _rewardToken; flywheelRewards = _flywheelRewards; flywheelBooster = _flywheelBooster; - - applyBoosting = address(_flywheelBooster) != address(0); } /// @notice initialize a new strategy @@ -89,11 +86,16 @@ contract FlywheelCore is Auth { /// @notice swap out the flywheel rewards contract function setFlywheelRewards(IFlywheelRewards newFlywheelRewards) external requiresAuth { - address oldFlywheelRewards = address(flywheelRewards); - flywheelRewards = newFlywheelRewards; - emit FlywheelRewardsUpdate(oldFlywheelRewards, address(newFlywheelRewards)); + emit FlywheelRewardsUpdate(address(newFlywheelRewards)); + } + + /// @notice swap out the flywheel booster contract + function setBooster(IFlywheelBooster newBooster) external requiresAuth { + flywheelBooster = newBooster; + + emit FlywheelBoosterUpdate(address(newBooster)); } /// @notice accrue rewards for a single user on a strategy @@ -137,7 +139,7 @@ contract FlywheelCore is Auth { rewardsState = state; if (strategyRewardsAccrued > 0) { // use the booster or token supply to calculate reward index denominator - uint256 supplyTokens = applyBoosting ? flywheelBooster.boostedTotalSupply(strategy): strategy.totalSupply(); + uint256 supplyTokens = address(flywheelBooster) != address(0) ? flywheelBooster.boostedTotalSupply(strategy): strategy.totalSupply(); // accumulate rewards per token onto the index, multiplied by fixed-point factor rewardsState = RewardsState({ @@ -165,7 +167,7 @@ contract FlywheelCore is Auth { uint224 deltaIndex = supplyIndex - supplierIndex; // use the booster or token balance to calculate reward balance multiplier - uint256 supplierTokens = applyBoosting ? flywheelBooster.boostedBalanceOf(strategy, user) : strategy.balanceOf(user); + uint256 supplierTokens = address(flywheelBooster) != address(0) ? flywheelBooster.boostedBalanceOf(strategy, user) : strategy.balanceOf(user); // accumulate rewards by multiplying user tokens by rewardsPerToken index and adding on unclaimed uint256 supplierDelta = supplierTokens * deltaIndex / ONE; diff --git a/src/test/FlywheelTest.sol b/src/test/FlywheelTest.sol index fece727..95ad456 100644 --- a/src/test/FlywheelTest.sol +++ b/src/test/FlywheelTest.sol @@ -53,11 +53,23 @@ contract FlywheelTest is DSTestPlus { require(flywheel.flywheelRewards() == IFlywheelRewards(address(1))); } - function testFailSetFlywheelRewards() public { + function testSetFlywheelRewardsUnauthorized() public { hevm.prank(address(1)); + hevm.expectRevert(bytes("UNAUTHORIZED")); flywheel.setFlywheelRewards(IFlywheelRewards(address(1))); } + function testSetFlywheelBooster() public { + flywheel.setBooster(IFlywheelBooster(address(1))); + require(flywheel.flywheelBooster() == IFlywheelBooster(address(1))); + } + + function testSetFlywheelBoosterUnauthorized() public { + hevm.prank(address(1)); + hevm.expectRevert(bytes("UNAUTHORIZED")); + flywheel.setBooster(IFlywheelBooster(address(1))); + } + function testAccrue() public { strategy.mint(user, 1 ether); strategy.mint(user2, 3 ether);