Skip to content

Commit

Permalink
test: test fee collector with clientTakeRate
Browse files Browse the repository at this point in the history
  • Loading branch information
cucupac committed Jan 30, 2024
1 parent 13ea2ef commit f5ee599
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 52 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ Tests:

- [ ] Invariant: the clientTakeRate + userTakeRate = clientRate
- [ ] Invariant: the totalTokenAmt - sum(clientFeesToken) = (1 - clientRate) \* totalTokenAmt
- [ ] Unit test setClientTakeRate()
- [ ] Unit test getUserSavings()
- [ ] Unit test FeeLib via Test Harness
- [ ] Account for userSavings in all affected FeeCollector unit and integration tests
- [x] Unit test setClientTakeRate()
- [x] Unit test getClientAllocations()
- [x] Unit test FeeLib via Test Harness
- [x] Account for userSavings in all affected FeeCollector unit and integration tests

Considerations:

Expand Down
18 changes: 13 additions & 5 deletions src/FeeCollector.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ contract FeeCollector is Ownable {
* @param _token The token to collect fees in (the collateral token of the calling Position contract).
* @param _amt The total amount of fees to collect.
*/
function collectFees(address _client, address _token, uint256 _amt) external payable {
function collectFees(address _client, address _token, uint256 _amt, uint256 _clientFee) external payable {
// 1. Transfer tokens to this contract
SafeTransferLib.safeTransferFrom(ERC20(_token), msg.sender, address(this), _amt);

// 2. Update client balances
if (_client != address(0)) {
uint256 clientFee = (_amt * clientRate) / 100;
balances[_client][_token] += clientFee;
totalClientBalances[_token] += clientFee;
balances[_client][_token] += _clientFee;
totalClientBalances[_token] += _clientFee;
}
}

Expand Down Expand Up @@ -81,10 +80,19 @@ contract FeeCollector is Ownable {
* @param _client The address where a client operator will receive protocols fees.
* @param _maxFee The maximum amount of fees the protocol will collect.
*/
function getUserSavings(address _client, uint256 _maxFee) public view returns (uint256 userSavings) {
function getClientAllocations(address _client, uint256 _maxFee)
public
view
returns (uint256 userSavings, uint256 clientFee)
{
// 1. Calculate user savings
uint256 userTakeRate = 100 - clientTakeRates[_client];
uint256 userPercentOfProtocolFee = (userTakeRate * clientRate) / 100;
userSavings = (userPercentOfProtocolFee * _maxFee) / 100;

// 2. Calculate client fee
uint256 maxClientFee = (_maxFee * clientRate) / 100;
clientFee = maxClientFee - userSavings;
}

/* ****************************************************************************
Expand Down
7 changes: 5 additions & 2 deletions src/interfaces/IFeeCollector.sol
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ interface IFeeCollector {
* @param _token The token to collect fees in (the collateral token of the calling Position contract).
* @param _amt The total amount of fees to collect.
*/
function collectFees(address _client, address _token, uint256 _amt) external payable;
function collectFees(address _client, address _token, uint256 _amt, uint256 _clientFee) external payable;
/**
* @notice Withdraw collected fees from this contract.
* @param _token The token address to withdraw.
Expand All @@ -64,7 +64,10 @@ interface IFeeCollector {
* @param _client The address where a client operator will receive protocols fees.
* @param _protocolFee The maximum amount of fees the protocol will collect.
*/
function getUserSavings(address _client, uint256 _protocolFee) external view returns (uint256 userSavings);
function getClientAllocations(address _client, uint256 _protocolFee)
external
view
returns (uint256 userSavings, uint256 clientFee);

/* ****************************************************************************
**
Expand Down
10 changes: 5 additions & 5 deletions src/libraries/FeeLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ library FeeLib {
*/
function takeProtocolFee(address _token, uint256 _cAmt, address _client) internal returns (uint256 cAmtNet) {
uint256 maxFee = (_cAmt * PROTOCOL_FEE_RATE) / 1000;
uint256 userSavings = IFeeCollector(FEE_COLLECTOR).getUserSavings(_client, maxFee);
uint256 fee = maxFee - userSavings;
cAmtNet = _cAmt - fee;
SafeTransferLib.safeApprove(ERC20(_token), FEE_COLLECTOR, fee);
IFeeCollector(FEE_COLLECTOR).collectFees(_client, _token, fee);
(uint256 userSavings, uint256 clientFee) = IFeeCollector(FEE_COLLECTOR).getClientAllocations(_client, maxFee);
uint256 totalFee = maxFee - userSavings;
cAmtNet = _cAmt - totalFee;
SafeTransferLib.safeApprove(ERC20(_token), FEE_COLLECTOR, totalFee);
IFeeCollector(FEE_COLLECTOR).collectFees(_client, _token, totalFee, clientFee);
}
}
75 changes: 69 additions & 6 deletions test/FeeCollector.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import { PositionFactory } from "src/PositionFactory.sol";
import { FeeCollector } from "src/FeeCollector.sol";
import { Assets, CONTRACT_DEPLOYER, TEST_CLIENT, CLIENT_RATE, USDC, WETH, WBTC } from "test/common/Constants.t.sol";
import { TokenUtils } from "test/common/utils/TokenUtils.t.sol";
import { FeeUtils } from "test/common/utils/FeeUtils.t.sol";
import { IERC20 } from "src/interfaces/token/IERC20.sol";

contract FeeCollectorTest is Test, TokenUtils {
contract FeeCollectorTest is Test, TokenUtils, FeeUtils {
/* solhint-disable func-name-mixedcase */

struct TestPosition {
Expand Down Expand Up @@ -106,7 +107,7 @@ contract FeeCollectorTest is Test, TokenUtils {
uint256 preClientFeeBalance = feeCollector.balances(TEST_CLIENT, feeToken);

// Act: collect fees
feeCollector.collectFees(TEST_CLIENT, feeToken, _protocolFee);
feeCollector.collectFees(TEST_CLIENT, feeToken, _protocolFee, clientFee);

// Post-act balances
uint256 postContractBalance = IERC20(feeToken).balanceOf(feeCollectorAddr);
Expand Down Expand Up @@ -143,7 +144,8 @@ contract FeeCollectorTest is Test, TokenUtils {
uint256 preTotalClientBalances = feeCollector.totalClientBalances(feeToken);

// Act: collect fees
feeCollector.collectFees(address(0), feeToken, _protocolFee);
uint256 clientFee = (_protocolFee * CLIENT_RATE) / 100;
feeCollector.collectFees(address(0), feeToken, _protocolFee, clientFee);

// Post-act balances
uint256 postContractBalance = IERC20(feeToken).balanceOf(feeCollectorAddr);
Expand Down Expand Up @@ -175,7 +177,8 @@ contract FeeCollectorTest is Test, TokenUtils {
IERC20(feeToken).approve(feeCollectorAddr, _amount);

// Collect fees
feeCollector.collectFees(TEST_CLIENT, feeToken, _amount);
uint256 clientFee = (_amount * CLIENT_RATE) / 100;
feeCollector.collectFees(TEST_CLIENT, feeToken, _amount, clientFee);

// Pre-act balances
uint256 preContractBalance = IERC20(feeToken).balanceOf(feeCollectorAddr);
Expand Down Expand Up @@ -249,6 +252,64 @@ contract FeeCollectorTest is Test, TokenUtils {
feeCollector.setClientRate(_clientRate);
}

/// @dev
// - The current client rate should be updated to new client rate
function testFuzz_SetClientTakeRate(uint256 _clientTakeRate) public {
// Assumptions
_clientTakeRate = bound(_clientTakeRate, 0, 100);

// Pre-act data
uint256 preClientTakeRate = feeCollector.clientTakeRates(TEST_CLIENT);

// Assertions
assertEq(preClientTakeRate, 0);

// Act
vm.prank(TEST_CLIENT);
feeCollector.setClientTakeRate(_clientTakeRate);

// Post-act data
uint256 postClientTakeRate = feeCollector.clientTakeRates(TEST_CLIENT);

// Assertions
assertEq(postClientTakeRate, _clientTakeRate);
}

/// @dev
// - The user savings should be correct according to what's calculated in expectations
// - The user savings should be <= maxClientFee
// - The above should be true for all fee tokens
// - The above should be true for fuzzed _maxFee and _clientTakeRate
function testFuzz_GetClientAllocations(uint256 _maxFee, uint256 _clientTakeRate) public {
for (uint256 i; i < positions.length; i++) {
// Test Variables
address feeToken = positions[i].cToken;

// Bound fuzzed variables
_maxFee = bound(_maxFee, assets.minCAmts(feeToken), assets.maxCAmts(feeToken));
_clientTakeRate = bound(_clientTakeRate, 0, 100);

// Setup
vm.prank(TEST_CLIENT);
feeCollector.setClientTakeRate(_clientTakeRate);

// Expectations
uint256 maxClientFee = (CLIENT_RATE * _maxFee) / 100;
(uint256 expectedUserSavings, uint256 expectedClientFee) =
_getExpectedClientAllocations(_maxFee, _clientTakeRate);

// Act
(uint256 userSavings, uint256 clientFee) = feeCollector.getClientAllocations(TEST_CLIENT, _maxFee);

// Assertions
assertEq(userSavings, expectedUserSavings);
assertEq(clientFee, expectedClientFee);
assertEq(userSavings + clientFee, maxClientFee);
assertLe(userSavings, maxClientFee);
assertLe(clientFee, maxClientFee);
}
}

/// @dev
// - The FeeCollector's native balance should decrease by the amount transferred.
// - The owner's native balance should increase by the amount transferred.
Expand Down Expand Up @@ -310,7 +371,8 @@ contract FeeCollectorTest is Test, TokenUtils {
IERC20(token).approve(feeCollectorAddr, _amount);

// Collect fees
feeCollector.collectFees(TEST_CLIENT, token, _amount);
uint256 clientFee = (_amount * CLIENT_RATE) / 100;
feeCollector.collectFees(TEST_CLIENT, token, _amount, clientFee);

// Pre-act balances
uint256 preContractTokenBalance = IERC20(token).balanceOf(feeCollectorAddr);
Expand Down Expand Up @@ -353,7 +415,8 @@ contract FeeCollectorTest is Test, TokenUtils {
IERC20(token).approve(feeCollectorAddr, _amount);

// Collect fees
feeCollector.collectFees(TEST_CLIENT, token, _amount);
uint256 clientFee = (_amount * CLIENT_RATE) / 100;
feeCollector.collectFees(TEST_CLIENT, token, _amount, clientFee);

// Act: attempt to extract ERC20 token
vm.prank(_sender);
Expand Down
24 changes: 24 additions & 0 deletions test/common/utils/FeeUtils.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.21;

// External Imports
import { Test } from "forge-std/Test.sol";

// Local Imports
import { CLIENT_RATE } from "test/common/Constants.t.sol";

contract FeeUtils is Test {
function _getExpectedClientAllocations(uint256 _maxFee, uint256 _clientTakeRate)
internal
pure
returns (uint256 userSavings, uint256 clientFee)
{
uint256 userTakeRate = 100 - _clientTakeRate;
uint256 userPercentOfProtocolFee = (userTakeRate * CLIENT_RATE) / 100;
userSavings = (userPercentOfProtocolFee * _maxFee) / 100;

// 2. Calculate client fee
uint256 maxClientFee = (_maxFee * CLIENT_RATE) / 100;
clientFee = maxClientFee - userSavings;
}
}
Loading

0 comments on commit f5ee599

Please sign in to comment.