diff --git a/src/TokenizedStrategy.sol b/src/TokenizedStrategy.sol index 7ea1afc..2347e4a 100644 --- a/src/TokenizedStrategy.sol +++ b/src/TokenizedStrategy.sol @@ -746,26 +746,25 @@ contract TokenizedStrategy { /** * @notice Total number of underlying assets that can - * be deposited by `_owner` into the strategy, where `owner` - * corresponds to the receiver of a {deposit} call. + * be deposited into the strategy, where `receiver` + * corresponds to the receiver of the shares of a {deposit} call. * - * @param owner The address depositing. - * @return . The max that `owner` can deposit in `asset`. + * @param receiver The address receiving the shares. + * @return . The max that `receiver` can deposit in `asset`. */ - function maxDeposit(address owner) external view returns (uint256) { - return _maxDeposit(_strategyStorage(), owner); + function maxDeposit(address receiver) external view returns (uint256) { + return _maxDeposit(_strategyStorage(), receiver); } /** - * @notice Total number of shares that can be minted by `owner` - * into the strategy, where `_owner` corresponds to the receiver + * @notice Total number of shares that can be minted to `receiver` * of a {mint} call. * - * @param owner The address minting. - * @return _maxMint The max that `owner` can mint in shares. + * @param receiver The address receiving the shares. + * @return _maxMint The max that `receiver` can mint in shares. */ - function maxMint(address owner) external view returns (uint256) { - return _maxMint(_strategyStorage(), owner); + function maxMint(address receiver) external view returns (uint256) { + return _maxMint(_strategyStorage(), receiver); } /** @@ -840,12 +839,14 @@ contract TokenizedStrategy { uint256 assets, Math.Rounding _rounding ) internal view returns (uint256) { - // Saves an extra SLOAD if totalAssets() is non-zero. - uint256 totalAssets_ = _totalAssets(S); + // Saves an extra SLOAD if values are non-zero. uint256 totalSupply_ = _totalSupply(S); + // If supply is 0, PPS = 1. + if (totalSupply_ == 0) return assets; + uint256 totalAssets_ = _totalAssets(S); // If assets are 0 but supply is not PPS = 0. - if (totalAssets_ == 0) return totalSupply_ == 0 ? assets : 0; + if (totalAssets_ == 0) return 0; return assets.mulDiv(totalSupply_, totalAssets_, _rounding); } @@ -868,23 +869,23 @@ contract TokenizedStrategy { /// @dev Internal implementation of {maxDeposit}. function _maxDeposit( StrategyData storage S, - address owner + address receiver ) internal view returns (uint256) { - // Cannot deposit when shutdown. - if (S.shutdown) return 0; + // Cannot deposit when shutdown or to the strategy. + if (S.shutdown || receiver == address(this)) return 0; - return IBaseStrategy(address(this)).availableDepositLimit(owner); + return IBaseStrategy(address(this)).availableDepositLimit(receiver); } /// @dev Internal implementation of {maxMint}. function _maxMint( StrategyData storage S, - address owner + address receiver ) internal view returns (uint256 maxMint_) { - // Cannot mint when shutdown. - if (S.shutdown) return 0; + // Cannot mint when shutdown or to the strategy. + if (S.shutdown || receiver == address(this)) return 0; - maxMint_ = IBaseStrategy(address(this)).availableDepositLimit(owner); + maxMint_ = IBaseStrategy(address(this)).availableDepositLimit(receiver); if (maxMint_ != type(uint256).max) { maxMint_ = _convertToShares(S, maxMint_, Math.Rounding.Down); } @@ -956,8 +957,6 @@ contract TokenizedStrategy { uint256 assets, uint256 shares ) internal { - require(receiver != address(this), "ERC4626: mint to self"); - // Cache storage variables used more than once. ERC20 _asset = S.asset; diff --git a/src/test/Accounting.t.sol b/src/test/Accounting.t.sol index d851c51..c0ea7a6 100644 --- a/src/test/Accounting.t.sol +++ b/src/test/Accounting.t.sol @@ -439,10 +439,10 @@ contract AccountingTest is Setup { setFees(0, 0); mintAndDepositIntoStrategy(strategy, _address, _amount); - uint256 toLoose = (_amount * _lossFactor) / MAX_BPS; + uint256 toLose = (_amount * _lossFactor) / MAX_BPS; // Simulate a loss. vm.prank(address(yieldSource)); - asset.transfer(address(69), toLoose); + asset.transfer(address(69), toLose); vm.expectRevert("too much loss"); vm.prank(_address); @@ -465,13 +465,13 @@ contract AccountingTest is Setup { setFees(0, 0); mintAndDepositIntoStrategy(strategy, _address, _amount); - uint256 toLoose = (_amount * _lossFactor) / MAX_BPS; + uint256 toLose = (_amount * _lossFactor) / MAX_BPS; // Simulate a loss. vm.prank(address(yieldSource)); - asset.transfer(address(69), toLoose); + asset.transfer(address(69), toLose); uint256 beforeBalance = asset.balanceOf(_address); - uint256 expectedOut = _amount - toLoose; + uint256 expectedOut = _amount - toLose; // Withdraw the full amount before the loss is reported. vm.prank(_address); strategy.withdraw(_amount, _address, _address, _lossFactor); @@ -499,13 +499,13 @@ contract AccountingTest is Setup { setFees(0, 0); mintAndDepositIntoStrategy(strategy, _address, _amount); - uint256 toLoose = (_amount * _lossFactor) / MAX_BPS; + uint256 toLose = (_amount * _lossFactor) / MAX_BPS; // Simulate a loss. vm.prank(address(yieldSource)); - asset.transfer(address(69), toLoose); + asset.transfer(address(69), toLose); uint256 beforeBalance = asset.balanceOf(_address); - uint256 expectedOut = _amount - toLoose; + uint256 expectedOut = _amount - toLose; // Withdraw the full amount before the loss is reported. vm.prank(_address); strategy.redeem(_amount, _address, _address); @@ -533,10 +533,10 @@ contract AccountingTest is Setup { setFees(0, 0); mintAndDepositIntoStrategy(strategy, _address, _amount); - uint256 toLoose = (_amount * _lossFactor) / MAX_BPS; + uint256 toLose = (_amount * _lossFactor) / MAX_BPS; // Simulate a loss. vm.prank(address(yieldSource)); - asset.transfer(address(69), toLoose); + asset.transfer(address(69), toLose); vm.expectRevert("too much loss"); vm.prank(_address); @@ -559,13 +559,13 @@ contract AccountingTest is Setup { setFees(0, 0); mintAndDepositIntoStrategy(strategy, _address, _amount); - uint256 toLoose = (_amount * _lossFactor) / MAX_BPS; + uint256 toLose = (_amount * _lossFactor) / MAX_BPS; // Simulate a loss. vm.prank(address(yieldSource)); - asset.transfer(address(69), toLoose); + asset.transfer(address(69), toLose); uint256 beforeBalance = asset.balanceOf(_address); - uint256 expectedOut = _amount - toLoose; + uint256 expectedOut = _amount - toLose; // First set it to just under the expected loss. vm.expectRevert("too much loss"); @@ -613,4 +613,88 @@ contract AccountingTest is Setup { assertEq(asset.balanceOf(address(yieldSource)), _amount); } + + function test_deposit_zeroAssetsPositiveSupply_reverts( + address _address, + uint256 _amount + ) public { + _amount = bound(_amount, minFuzzAmount, maxFuzzAmount); + vm.assume( + _address != address(0) && + _address != address(strategy) && + _address != address(yieldSource) + ); + + setFees(0, 0); + mintAndDepositIntoStrategy(strategy, _address, _amount); + + uint256 toLose = _amount; + // Simulate a loss. + vm.prank(address(yieldSource)); + asset.transfer(address(69), toLose); + + vm.prank(keeper); + strategy.report(); + + // Should still have shares but no assets + checkStrategyTotals(strategy, 0, 0, 0, _amount); + + assertEq(strategy.balanceOf(_address), _amount); + assertEq(asset.balanceOf(address(strategy)), 0); + assertEq(asset.balanceOf(address(yieldSource)), 0); + + asset.mint(_address, _amount); + vm.prank(_address); + asset.approve(address(strategy), _amount); + + vm.expectRevert("ZERO_SHARES"); + vm.prank(_address); + strategy.deposit(_amount, _address); + + assertEq(strategy.convertToAssets(_amount), 0); + assertEq(strategy.convertToShares(_amount), 0); + assertEq(strategy.pricePerShare(), 0); + } + + function test_mint_zeroAssetsPositiveSupply_reverts( + address _address, + uint256 _amount + ) public { + _amount = bound(_amount, minFuzzAmount, maxFuzzAmount); + vm.assume( + _address != address(0) && + _address != address(strategy) && + _address != address(yieldSource) + ); + + setFees(0, 0); + mintAndDepositIntoStrategy(strategy, _address, _amount); + + uint256 toLose = _amount; + // Simulate a loss. + vm.prank(address(yieldSource)); + asset.transfer(address(69), toLose); + + vm.prank(keeper); + strategy.report(); + + // Should still have shares but no assets + checkStrategyTotals(strategy, 0, 0, 0, _amount); + + assertEq(strategy.balanceOf(_address), _amount); + assertEq(asset.balanceOf(address(strategy)), 0); + assertEq(asset.balanceOf(address(yieldSource)), 0); + + asset.mint(_address, _amount); + vm.prank(_address); + asset.approve(address(strategy), _amount); + + vm.expectRevert("ZERO_ASSETS"); + vm.prank(_address); + strategy.mint(_amount, _address); + + assertEq(strategy.convertToAssets(_amount), 0); + assertEq(strategy.convertToShares(_amount), 0); + assertEq(strategy.pricePerShare(), 0); + } }