Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Use sync.Map to avoid contention on RateLimiter
Browse files Browse the repository at this point in the history
Signed-off-by: TungHoang <[email protected]>
  • Loading branch information
LaPetiteSouris committed May 16, 2023
1 parent e9585fb commit 6061035
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
43 changes: 25 additions & 18 deletions plugins/rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,56 @@ type RateLimitExceeded error
type accessRecords struct {
limiter *rate.Limiter
lastAccess time.Time
mutex *sync.Mutex
}

// LimiterStore stores the access records for each user
type LimiterStore struct {
// accessPerUser is a synchronized map of userID to accessRecords
accessPerUser map[string]*accessRecords
mutex *sync.Mutex
accessPerUser *sync.Map
requestPerSec int
burstSize int
mutex *sync.Mutex
cleanupInterval time.Duration
}

// Allow takes a userID and returns an error if the user has exceeded the rate limit
func (l *LimiterStore) Allow(userID string) error {
l.mutex.Lock()
defer l.mutex.Unlock()
if _, ok := l.accessPerUser[userID]; !ok {
l.accessPerUser[userID] = &accessRecords{
lastAccess: time.Now(),
limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize),
}
}
l.accessPerUser[userID].lastAccess = time.Now()

if !l.accessPerUser[userID].limiter.Allow() {
accessRecord, _ := l.accessPerUser.LoadOrStore(userID, &accessRecords{
limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize),
lastAccess: time.Now(),
mutex: &sync.Mutex{},
})
accessRecord.(*accessRecords).mutex.Lock()
defer accessRecord.(*accessRecords).mutex.Unlock()

accessRecord.(*accessRecords).lastAccess = time.Now()
l.accessPerUser.Store(userID, accessRecord)

if !accessRecord.(*accessRecords).limiter.Allow() {
return RateLimitExceeded(fmt.Errorf("rate limit exceeded"))
}

return nil
}

// clean removes the access records for users who have not accessed the system for a while
func (l *LimiterStore) clean() {
l.mutex.Lock()
defer l.mutex.Unlock()
for userID, accessRecord := range l.accessPerUser {
if time.Since(accessRecord.lastAccess) > l.cleanupInterval {
delete(l.accessPerUser, userID)
l.accessPerUser.Range(func(key, value interface{}) bool {
value.(*accessRecords).mutex.Lock()
defer value.(*accessRecords).mutex.Unlock()
if time.Since(value.(*accessRecords).lastAccess) > l.cleanupInterval {
l.accessPerUser.Delete(key)
}
}
return true
})
}

func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore {
l := &LimiterStore{
accessPerUser: make(map[string]*accessRecords),
accessPerUser: &sync.Map{},
mutex: &sync.Mutex{},
requestPerSec: requestPerSec,
burstSize: burstSize,
Expand All @@ -80,6 +86,7 @@ func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Du
return l
}

// RateLimiter is a struct that implements the RateLimiter interface from grpc middleware
type RateLimiter struct {
limiter *LimiterStore
}
Expand Down
25 changes: 21 additions & 4 deletions plugins/rate_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestNewRateLimiter(t *testing.T) {
}

func TestLimiterAllow(t *testing.T) {
rlStore := newRateLimitStore(1, 1, time.Second)
rlStore := newRateLimitStore(1, 1, 10*time.Second)
assert.NoError(t, rlStore.Allow("hello"))
assert.Error(t, rlStore.Allow("hello"))
time.Sleep(time.Second)
Expand Down Expand Up @@ -97,13 +97,30 @@ func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) {
func TestRateLimiterUpdateLastAccessTime(t *testing.T) {
rlStore := newRateLimitStore(2, 2, time.Second)
assert.NoError(t, rlStore.Allow("hello"))
firstAccessTime := rlStore.accessPerUser["hello"].lastAccess
// get last access time

accessRecord, _ := rlStore.accessPerUser.Load("hello")
accessRecord.(*accessRecords).mutex.Lock()
firstAccessTime := accessRecord.(*accessRecords).lastAccess
accessRecord.(*accessRecords).mutex.Unlock()

assert.NoError(t, rlStore.Allow("hello"))
secondAccessTime := rlStore.accessPerUser["hello"].lastAccess

accessRecord, _ = rlStore.accessPerUser.Load("hello")
accessRecord.(*accessRecords).mutex.Lock()
secondAccessTime := accessRecord.(*accessRecords).lastAccess
accessRecord.(*accessRecords).mutex.Unlock()

assert.True(t, secondAccessTime.After(firstAccessTime))

// Verify that the last access time is updated even when user is rate limited
assert.Error(t, rlStore.Allow("hello"))
thirdAccessTime := rlStore.accessPerUser["hello"].lastAccess

accessRecord, _ = rlStore.accessPerUser.Load("hello")
accessRecord.(*accessRecords).mutex.Lock()
thirdAccessTime := accessRecord.(*accessRecords).lastAccess
accessRecord.(*accessRecords).mutex.Unlock()

assert.True(t, thirdAccessTime.After(secondAccessTime))

}

0 comments on commit 6061035

Please sign in to comment.