-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathRateLimiter.sol
339 lines (295 loc) · 12 KB
/
RateLimiter.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
// SPDX-License-Identifier: Apache 2
pragma solidity >=0.8.8 <0.9.0;
import "../interfaces/IRateLimiter.sol";
import "../interfaces/IRateLimiterEvents.sol";
import "./TransceiverHelpers.sol";
import "./TransceiverStructs.sol";
import "../libraries/TrimmedAmount.sol";
import "openzeppelin-contracts/contracts/utils/math/SafeCast.sol";
/// @dev Abstract base contract that keeps rate-limit parameters and queued-transfer
/// records for outbound transfers and for inbound transfers keyed by a uint16 chain id.
/// All of its mutable state is reached through the explicitly computed storage slots below.
abstract contract RateLimiter is IRateLimiter, IRateLimiterEvents {
// Lets TrimmedAmountLib functions be called as members of TrimmedAmount values.
using TrimmedAmountLib for TrimmedAmount;
/// @dev The duration (in seconds) it takes for the limits to fully replenish.
/// A value of zero makes the capacity/consume/backfill helpers in this contract
/// treat rate limiting as disabled.
uint64 public immutable rateLimitDuration;
/// =============== STORAGE ===============================================
// Each slot constant is keccak256(<label>) - 1. The `_get*Storage` accessors bind
// storage references to these slots through inline assembly.
// Slot of the single outbound RateLimitParams struct.
bytes32 private constant OUTBOUND_LIMIT_PARAMS_SLOT =
bytes32(uint256(keccak256("ntt.outboundLimitParams")) - 1);
// Base slot of the mapping from uint64 queue sequence to OutboundQueuedTransfer.
bytes32 private constant OUTBOUND_QUEUE_SLOT =
bytes32(uint256(keccak256("ntt.outboundQueue")) - 1);
// Base slot of the mapping from uint16 chain id to that chain's inbound RateLimitParams.
bytes32 private constant INBOUND_LIMIT_PARAMS_SLOT =
bytes32(uint256(keccak256("ntt.inboundLimitParams")) - 1);
// Base slot of the mapping from bytes32 digest to InboundQueuedTransfer.
bytes32 private constant INBOUND_QUEUE_SLOT =
bytes32(uint256(keccak256("ntt.inboundQueue")) - 1);
/// @dev Returns a storage reference to the outbound rate-limit parameters,
/// bound to the slot held in `OUTBOUND_LIMIT_PARAMS_SLOT`.
function _getOutboundLimitParamsStorage() internal pure returns (RateLimitParams storage $) {
    bytes32 position = OUTBOUND_LIMIT_PARAMS_SLOT;
    assembly ("memory-safe") {
        $.slot := position
    }
}
/// @dev Returns a storage reference to the outbound queue mapping
/// (queue sequence => queued transfer), bound to `OUTBOUND_QUEUE_SLOT`.
function _getOutboundQueueStorage()
    internal
    pure
    returns (mapping(uint64 => OutboundQueuedTransfer) storage $)
{
    bytes32 position = OUTBOUND_QUEUE_SLOT;
    assembly ("memory-safe") {
        $.slot := position
    }
}
/// @dev Returns a storage reference to the inbound rate-limit mapping
/// (chain id => rate-limit parameters), bound to `INBOUND_LIMIT_PARAMS_SLOT`.
function _getInboundLimitParamsStorage()
    internal
    pure
    returns (mapping(uint16 => RateLimitParams) storage $)
{
    bytes32 position = INBOUND_LIMIT_PARAMS_SLOT;
    assembly ("memory-safe") {
        $.slot := position
    }
}
/// @dev Returns a storage reference to the inbound queue mapping
/// (digest => queued transfer), bound to `INBOUND_QUEUE_SLOT`.
function _getInboundQueueStorage()
    internal
    pure
    returns (mapping(bytes32 => InboundQueuedTransfer) storage $)
{
    bytes32 position = INBOUND_QUEUE_SLOT;
    assembly ("memory-safe") {
        $.slot := position
    }
}
/// @dev Records the replenish duration as an immutable.
/// Reverts with `UndefinedRateLimiting` unless the arguments agree: either a zero
/// duration together with `_skipRateLimiting == true`, or a non-zero duration
/// together with `_skipRateLimiting == false`.
/// @param _rateLimitDuration Seconds for a limit to fully replenish; zero only when skipping.
/// @param _skipRateLimiting Whether rate limiting is meant to be disabled.
constructor(uint64 _rateLimitDuration, bool _skipRateLimiting) {
    // The two flags must match: "duration is zero" has to equal "skip requested".
    bool durationIsZero = _rateLimitDuration == 0;
    if (durationIsZero != _skipRateLimiting) {
        revert UndefinedRateLimiting();
    }
    rateLimitDuration = _rateLimitDuration;
}
/// @dev Installs `limit` on `rateLimitParams` and re-bases its stored capacity.
/// When no limit was set before (the stored limit is null) the capacity starts at
/// `limit`; otherwise it is derived from the current replenished capacity via
/// `_calculateNewCurrentCapacity`. The last-transaction timestamp is reset to now.
/// @param limit The new limit to store.
/// @param rateLimitParams The storage struct to update.
function _setLimit(TrimmedAmount limit, RateLimitParams storage rateLimitParams) internal {
    TrimmedAmount previousLimit = rateLimitParams.limit;
    // Default for the first-time case: a fresh limit starts with full capacity.
    TrimmedAmount updatedCapacity = limit;
    if (!previousLimit.isNull()) {
        TrimmedAmount replenished = _getCurrentCapacity(rateLimitParams);
        updatedCapacity = _calculateNewCurrentCapacity(limit, previousLimit, replenished);
    }
    rateLimitParams.currentCapacity = updatedCapacity;
    rateLimitParams.limit = limit;
    rateLimitParams.lastTxTimestamp = uint64(block.timestamp);
}
/// @dev Updates the outbound limit and emits `OutboundTransferLimitUpdated` with the
/// previous and new limits untrimmed to `tokenDecimals()`.
/// @param limit The new outbound limit.
function _setOutboundLimit(TrimmedAmount limit) internal virtual {
    RateLimitParams storage outboundParams = _getOutboundLimitParamsStorage();
    // Capture the old limit before `_setLimit` overwrites it.
    TrimmedAmount previousLimit = outboundParams.limit;
    uint8 decimals_ = tokenDecimals();
    _setLimit(limit, outboundParams);
    emit OutboundTransferLimitUpdated(previousLimit.untrim(decimals_), limit.untrim(decimals_));
}
/// @dev Returns the outbound rate-limit parameters as a memory struct, taken from the
/// storage reference produced by `_getOutboundLimitParamsStorage()`.
function getOutboundLimitParams() public pure virtual returns (RateLimitParams memory) {
// NOTE(review): this is declared `pure`, yet returning `memory` from a storage reference
// copies the struct out of contract storage — confirm whether `view` is the intended mutability.
return _getOutboundLimitParamsStorage();
}
/// @dev Returns the currently available outbound capacity, untrimmed to `tokenDecimals()`.
function getCurrentOutboundCapacity() public view virtual returns (uint256) {
    RateLimitParams memory outboundParams = getOutboundLimitParams();
    TrimmedAmount available = _getCurrentCapacity(outboundParams);
    return available.untrim(tokenDecimals());
}
/// @dev Returns the outbound queue entry stored under `queueSequence`.
/// @param queueSequence The sequence number the transfer was queued under.
function getOutboundQueuedTransfer(
    uint64 queueSequence
) public view virtual returns (OutboundQueuedTransfer memory) {
    mapping(uint64 => OutboundQueuedTransfer) storage outboundQueue = _getOutboundQueueStorage();
    return outboundQueue[queueSequence];
}
/// @dev Updates the inbound limit for `chainId_` and emits `InboundTransferLimitUpdated`
/// with the chain id and the previous and new limits untrimmed to `tokenDecimals()`.
/// @param limit The new inbound limit.
/// @param chainId_ The chain whose inbound limit is being set.
function _setInboundLimit(TrimmedAmount limit, uint16 chainId_) internal virtual {
    RateLimitParams storage chainParams = _getInboundLimitParamsStorage()[chainId_];
    // Capture the old limit before `_setLimit` overwrites it.
    TrimmedAmount previousLimit = chainParams.limit;
    uint8 decimals_ = tokenDecimals();
    _setLimit(limit, chainParams);
    emit InboundTransferLimitUpdated(
        chainId_, previousLimit.untrim(decimals_), limit.untrim(decimals_)
    );
}
/// @dev Returns a memory copy of the inbound rate-limit parameters stored for `chainId_`.
/// @param chainId_ The chain whose inbound parameters are requested.
function getInboundLimitParams(
    uint16 chainId_
) public view virtual returns (RateLimitParams memory) {
    mapping(uint16 => RateLimitParams) storage inboundParams = _getInboundLimitParamsStorage();
    return inboundParams[chainId_];
}
/// @dev Returns the currently available inbound capacity for `chainId_`,
/// untrimmed to `tokenDecimals()`.
/// @param chainId_ The chain whose inbound capacity is requested.
function getCurrentInboundCapacity(
    uint16 chainId_
) public view virtual returns (uint256) {
    RateLimitParams memory chainParams = getInboundLimitParams(chainId_);
    TrimmedAmount available = _getCurrentCapacity(chainParams);
    return available.untrim(tokenDecimals());
}
/// @dev Returns the inbound queue entry stored under `digest`.
/// @param digest The key the inbound transfer was queued under.
function getInboundQueuedTransfer(
    bytes32 digest
) public view virtual returns (InboundQueuedTransfer memory) {
    mapping(bytes32 => InboundQueuedTransfer) storage inboundQueue = _getInboundQueueStorage();
    return inboundQueue[digest];
}
/**
 * @dev Computes the capacity available right now for the given rate-limit parameters.
 *
 * With a zero `rateLimitDuration` (rate limiting skipped) the result is the maximum
 * uint64 amount. Otherwise it is the stored capacity plus the linearly replenished
 * share of the limit since `lastTxTimestamp`, capped at the limit. The result carries
 * the decimals of the stored capacity.
 *
 * @param rateLimitParams A memory snapshot of the parameters to evaluate.
 */
function _getCurrentCapacity(
    RateLimitParams memory rateLimitParams
) internal view returns (TrimmedAmount capacity) {
    // Duration zero means the limiter is bypassed: report the largest possible amount.
    if (rateLimitDuration == 0) {
        return
            packTrimmedAmount(type(uint64).max, rateLimitParams.currentCapacity.getDecimals());
    }

    // Amounts are 64-bit trimmed values, but every intermediate below is a 256-bit word.
    // That matters because (limit * elapsed) and the running sum can exceed uint64 —
    // e.g. a limit of uint64.max with more than one full duration elapsed. The final
    // min() against the limit brings the value back into uint64 range before the downcast.
    unchecked {
        uint256 limitAmount = rateLimitParams.limit.getAmount();
        uint256 elapsed = block.timestamp - rateLimitParams.lastTxTimestamp;
        // Multiply before dividing: dividing first loses significant precision when the
        // duration is close to the limit, and the product cannot overflow a 256-bit word.
        uint256 replenished = (limitAmount * elapsed) / rateLimitDuration;
        uint256 uncapped = rateLimitParams.currentCapacity.getAmount() + replenished;
        uint256 capped = min(uncapped, limitAmount);
        return packTrimmedAmount(
            SafeCast.toUint64(capped), rateLimitParams.currentCapacity.getDecimals()
        );
    }
}
/**
 * @dev Derives the capacity that should be stored after a limit change.
 *
 * When the limit shrinks, the capacity shrinks by the same amount, bottoming out at
 * zero (with the capacity's decimals). Otherwise the capacity grows by the increase.
 * Reverts with `CapacityCannotExceedLimit` if the result is above the new limit.
 *
 * @param newLimit The new limit
 * @param oldLimit The old limit
 * @param currentCapacity The current capacity
 */
function _calculateNewCurrentCapacity(
    TrimmedAmount newLimit,
    TrimmedAmount oldLimit,
    TrimmedAmount currentCapacity
) internal pure returns (TrimmedAmount newCurrentCapacity) {
    if (oldLimit > newLimit) {
        TrimmedAmount reduction = oldLimit - newLimit;
        if (currentCapacity > reduction) {
            newCurrentCapacity = currentCapacity - reduction;
        } else {
            // The reduction consumes all remaining capacity.
            newCurrentCapacity = packTrimmedAmount(0, currentCapacity.getDecimals());
        }
    } else {
        TrimmedAmount increase = newLimit - oldLimit;
        newCurrentCapacity = currentCapacity + increase;
    }

    if (newCurrentCapacity > newLimit) {
        revert CapacityCannotExceedLimit(newCurrentCapacity, newLimit);
    }
}
/// @dev Deducts `amount` from the current outbound capacity.
/// Does nothing when `rateLimitDuration` is zero.
/// @param amount The amount to consume.
function _consumeOutboundAmount(TrimmedAmount amount) internal {
    if (rateLimitDuration != 0) {
        TrimmedAmount available = _getCurrentCapacity(getOutboundLimitParams());
        _consumeRateLimitAmount(amount, available, _getOutboundLimitParamsStorage());
    }
}
/// @dev Adds `amount` back to the current outbound capacity (capped by the limit in
/// `_backfillRateLimitAmount`). Does nothing when `rateLimitDuration` is zero.
/// @param amount The amount to refill.
function _backfillOutboundAmount(TrimmedAmount amount) internal {
    if (rateLimitDuration != 0) {
        TrimmedAmount available = _getCurrentCapacity(getOutboundLimitParams());
        _backfillRateLimitAmount(amount, available, _getOutboundLimitParamsStorage());
    }
}
/// @dev Deducts `amount` from the current inbound capacity of `chainId_`.
/// Does nothing when `rateLimitDuration` is zero.
/// @param amount The amount to consume.
/// @param chainId_ The chain whose inbound capacity is charged.
function _consumeInboundAmount(TrimmedAmount amount, uint16 chainId_) internal {
    if (rateLimitDuration != 0) {
        TrimmedAmount available = _getCurrentCapacity(getInboundLimitParams(chainId_));
        _consumeRateLimitAmount(amount, available, _getInboundLimitParamsStorage()[chainId_]);
    }
}
/// @dev Adds `amount` back to the current inbound capacity of `chainId_` (capped by the
/// limit in `_backfillRateLimitAmount`). Does nothing when `rateLimitDuration` is zero.
/// @param amount The amount to refill.
/// @param chainId_ The chain whose inbound capacity is refilled.
function _backfillInboundAmount(TrimmedAmount amount, uint16 chainId_) internal {
    if (rateLimitDuration != 0) {
        TrimmedAmount available = _getCurrentCapacity(getInboundLimitParams(chainId_));
        _backfillRateLimitAmount(amount, available, _getInboundLimitParamsStorage()[chainId_]);
    }
}
/// @dev Stamps the parameters with the current block time and stores
/// `capacity - amount` as the new capacity.
/// NOTE(review): the outcome when `amount` exceeds `capacity` is decided by
/// TrimmedAmount's subtraction operator (presumably a revert) — confirm.
/// @param amount The amount being consumed.
/// @param capacity The capacity available before consumption.
/// @param rateLimitParams The storage struct to update.
function _consumeRateLimitAmount(
    TrimmedAmount amount,
    TrimmedAmount capacity,
    RateLimitParams storage rateLimitParams
) internal {
    rateLimitParams.lastTxTimestamp = uint64(block.timestamp);
    TrimmedAmount remaining = capacity - amount;
    rateLimitParams.currentCapacity = remaining;
}
/// @dev Replenishes capacity by `amount`; used to credit capacity back via backflows.
/// Stamps the parameters with the current block time and stores the saturating sum
/// of `capacity` and `amount`, capped at the stored limit.
/// @param amount The amount to add back.
/// @param capacity The capacity available before the refill.
/// @param rateLimitParams The storage struct to update.
function _backfillRateLimitAmount(
    TrimmedAmount amount,
    TrimmedAmount capacity,
    RateLimitParams storage rateLimitParams
) internal {
    rateLimitParams.lastTxTimestamp = uint64(block.timestamp);
    TrimmedAmount refilled = capacity.saturatingAdd(amount);
    rateLimitParams.currentCapacity = refilled.min(rateLimitParams.limit);
}
/// @dev Reports whether `amount` exceeds the current outbound capacity.
/// Always false when `rateLimitDuration` is zero.
/// @param amount The amount to check.
function _isOutboundAmountRateLimited(TrimmedAmount amount) internal view returns (bool) {
    if (rateLimitDuration == 0) {
        return false;
    }
    return _isAmountRateLimited(_getCurrentCapacity(getOutboundLimitParams()), amount);
}
/// @dev Reports whether `amount` exceeds the current inbound capacity of `chainId_`.
/// Always false when `rateLimitDuration` is zero.
/// @param amount The amount to check.
/// @param chainId_ The chain whose inbound capacity is consulted.
function _isInboundAmountRateLimited(
    TrimmedAmount amount,
    uint16 chainId_
) internal view returns (bool) {
    if (rateLimitDuration == 0) {
        return false;
    }
    return _isAmountRateLimited(_getCurrentCapacity(getInboundLimitParams(chainId_)), amount);
}
/// @dev Returns true when `capacity` is strictly less than `amount`.
/// @param capacity The capacity currently available.
/// @param amount The amount requested.
function _isAmountRateLimited(
    TrimmedAmount capacity,
    TrimmedAmount amount
) internal pure returns (bool) {
    bool insufficient = capacity < amount;
    return insufficient;
}
/// @dev Stores an outbound transfer in the outbound queue under `sequence` and emits
/// `OutboundTransferQueued(sequence)`. No check is made for an existing entry under
/// the same sequence; the slot is simply written.
/// @param sourceChain Stored as the entry's `sourceChain`.
/// @param sequence Key of the queue entry; also the event argument.
/// @param amount Stored as the entry's `amount`.
/// @param recipientChain Stored as the entry's `recipientChain`.
/// @param recipient Stored as the entry's `recipient`.
/// @param refundAddress Stored as the entry's `refundAddress`.
/// @param senderAddress Stored as the entry's `sender`.
/// @param transceiverInstructions Stored as the entry's `transceiverInstructions`.
function _enqueueOutboundTransfer(
uint16 sourceChain,
uint64 sequence,
TrimmedAmount amount,
uint16 recipientChain,
bytes32 recipient,
bytes32 refundAddress,
address senderAddress,
bytes memory transceiverInstructions
) internal {
_getOutboundQueueStorage()[sequence] = OutboundQueuedTransfer({
amount: amount,
recipientChain: recipientChain,
recipient: recipient,
refundAddress: refundAddress,
// Queue time: the current block timestamp narrowed to 64 bits.
txTimestamp: uint64(block.timestamp),
sender: senderAddress,
transceiverInstructions: transceiverInstructions,
sourceChain: sourceChain
});
emit OutboundTransferQueued(sequence);
}
/// @dev Stores an inbound transfer in the inbound queue under `digest`, stamped with
/// the current block time, and emits `InboundTransferQueued(digest)`.
/// @param sourceChain Stored as the entry's `sourceChain`.
/// @param digest Key of the queue entry; also the event argument.
/// @param amount Stored as the entry's `amount`.
/// @param recipient Stored as the entry's `recipient`.
function _enqueueInboundTransfer(
    uint16 sourceChain,
    bytes32 digest,
    TrimmedAmount amount,
    address recipient
) internal {
    InboundQueuedTransfer memory queued = InboundQueuedTransfer({
        amount: amount,
        recipient: recipient,
        txTimestamp: uint64(block.timestamp),
        sourceChain: sourceChain
    });
    _getInboundQueueStorage()[digest] = queued;
    emit InboundTransferQueued(digest);
}
/// @dev Declared without a body here, so a derived contract must supply it.
/// Within this contract the returned value is the decimals count used to untrim
/// amounts for the capacity getters and the limit-update events.
function tokenDecimals() public view virtual returns (uint8);
}