Skip to content

Commit

Permalink
Merge pull request fei-protocol#36 from fei-protocol/feat/mutableBooster
Browse files Browse the repository at this point in the history
make booster mutable
  • Loading branch information
Joeysantoro authored Mar 20, 2022
2 parents 310e882 + e527509 commit 20966c2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
26 changes: 14 additions & 12 deletions src/FlywheelCore.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ contract FlywheelCore is Auth {

event AddStrategy(address indexed newStrategy);

event FlywheelRewardsUpdate(address indexed oldFlywheelRewards, address indexed newFlywheelRewards);
event FlywheelRewardsUpdate(address indexed newFlywheelRewards);

event FlywheelBoosterUpdate(address indexed newBooster);

event AccrueRewards(ERC20 indexed cToken, address indexed owner, uint rewardsDelta, uint rewardsIndex);

Expand All @@ -45,7 +47,7 @@ contract FlywheelCore is Auth {
IFlywheelRewards public flywheelRewards;

/// @notice optional booster module for calculating virtual balances on strategies
IFlywheelBooster public immutable flywheelBooster;
IFlywheelBooster public flywheelBooster;

/// @notice the fixed point factor of flywheel
uint224 public constant ONE = 1e18;
Expand All @@ -59,9 +61,6 @@ contract FlywheelCore is Auth {
/// @notice The accrued but not yet transferred rewards for each user
mapping(address => uint256) public rewardsAccrued;

/// @dev immutable flag for short-circuiting boosting logic
bool internal immutable applyBoosting;

constructor(
ERC20 _rewardToken,
IFlywheelRewards _flywheelRewards,
Expand All @@ -72,8 +71,6 @@ contract FlywheelCore is Auth {
rewardToken = _rewardToken;
flywheelRewards = _flywheelRewards;
flywheelBooster = _flywheelBooster;

applyBoosting = address(_flywheelBooster) != address(0);
}

/// @notice initialize a new strategy
Expand All @@ -89,11 +86,16 @@ contract FlywheelCore is Auth {

/// @notice swap out the flywheel rewards contract
function setFlywheelRewards(IFlywheelRewards newFlywheelRewards) external requiresAuth {
address oldFlywheelRewards = address(flywheelRewards);

flywheelRewards = newFlywheelRewards;

emit FlywheelRewardsUpdate(oldFlywheelRewards, address(newFlywheelRewards));
emit FlywheelRewardsUpdate(address(newFlywheelRewards));
}

/// @notice swap out the flywheel booster contract
function setBooster(IFlywheelBooster newBooster) external requiresAuth {
flywheelBooster = newBooster;

emit FlywheelBoosterUpdate(address(newBooster));
}

/// @notice accrue rewards for a single user on a strategy
Expand Down Expand Up @@ -137,7 +139,7 @@ contract FlywheelCore is Auth {
rewardsState = state;
if (strategyRewardsAccrued > 0) {
// use the booster or token supply to calculate reward index denominator
uint256 supplyTokens = applyBoosting ? flywheelBooster.boostedTotalSupply(strategy): strategy.totalSupply();
uint256 supplyTokens = address(flywheelBooster) != address(0) ? flywheelBooster.boostedTotalSupply(strategy): strategy.totalSupply();

// accumulate rewards per token onto the index, multiplied by fixed-point factor
rewardsState = RewardsState({
Expand Down Expand Up @@ -165,7 +167,7 @@ contract FlywheelCore is Auth {

uint224 deltaIndex = supplyIndex - supplierIndex;
// use the booster or token balance to calculate reward balance multiplier
uint256 supplierTokens = applyBoosting ? flywheelBooster.boostedBalanceOf(strategy, user) : strategy.balanceOf(user);
uint256 supplierTokens = address(flywheelBooster) != address(0) ? flywheelBooster.boostedBalanceOf(strategy, user) : strategy.balanceOf(user);

// accumulate rewards by multiplying user tokens by rewardsPerToken index and adding on unclaimed
uint256 supplierDelta = supplierTokens * deltaIndex / ONE;
Expand Down
14 changes: 13 additions & 1 deletion src/test/FlywheelTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,23 @@ contract FlywheelTest is DSTestPlus {
require(flywheel.flywheelRewards() == IFlywheelRewards(address(1)));
}

function testFailSetFlywheelRewards() public {
function testSetFlywheelRewardsUnauthorized() public {
hevm.prank(address(1));
hevm.expectRevert(bytes("UNAUTHORIZED"));
flywheel.setFlywheelRewards(IFlywheelRewards(address(1)));
}

function testSetFlywheelBooster() public {
flywheel.setBooster(IFlywheelBooster(address(1)));
require(flywheel.flywheelBooster() == IFlywheelBooster(address(1)));
}

function testSetFlywheelBoosterUnauthorized() public {
hevm.prank(address(1));
hevm.expectRevert(bytes("UNAUTHORIZED"));
flywheel.setBooster(IFlywheelBooster(address(1)));
}

function testAccrue() public {
strategy.mint(user, 1 ether);
strategy.mint(user2, 3 ether);
Expand Down

0 comments on commit 20966c2

Please sign in to comment.