diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 90c85f39f..4c81657de 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -20,50 +20,56 @@ type RateLimitExceeded error type accessRecords struct { limiter *rate.Limiter lastAccess time.Time + mutex *sync.Mutex } // LimiterStore stores the access records for each user type LimiterStore struct { // accessPerUser is a synchronized map of userID to accessRecords - accessPerUser map[string]*accessRecords - mutex *sync.Mutex + accessPerUser *sync.Map requestPerSec int burstSize int + mutex *sync.Mutex cleanupInterval time.Duration } // Allow takes a userID and returns an error if the user has exceeded the rate limit func (l *LimiterStore) Allow(userID string) error { - l.mutex.Lock() - defer l.mutex.Unlock() - if _, ok := l.accessPerUser[userID]; !ok { - l.accessPerUser[userID] = &accessRecords{ - lastAccess: time.Now(), - limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), - } - } - l.accessPerUser[userID].lastAccess = time.Now() - - if !l.accessPerUser[userID].limiter.Allow() { + accessRecord, _ := l.accessPerUser.LoadOrStore(userID, &accessRecords{ + limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), + lastAccess: time.Now(), + mutex: &sync.Mutex{}, + }) + accessRecord.(*accessRecords).mutex.Lock() + defer accessRecord.(*accessRecords).mutex.Unlock() + + accessRecord.(*accessRecords).lastAccess = time.Now() + l.accessPerUser.Store(userID, accessRecord) + + if !accessRecord.(*accessRecords).limiter.Allow() { return RateLimitExceeded(fmt.Errorf("rate limit exceeded")) } return nil } +// clean removes the access records for users who have not accessed the system for a while func (l *LimiterStore) clean() { l.mutex.Lock() defer l.mutex.Unlock() - for userID, accessRecord := range l.accessPerUser { - if time.Since(accessRecord.lastAccess) > l.cleanupInterval { - delete(l.accessPerUser, userID) + l.accessPerUser.Range(func(key, value interface{}) bool { + value.(*accessRecords).mutex.Lock() + defer value.(*accessRecords).mutex.Unlock() + if time.Since(value.(*accessRecords).lastAccess) > l.cleanupInterval { + l.accessPerUser.Delete(key) } - } + return true + }) } func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore { l := &LimiterStore{ - accessPerUser: make(map[string]*accessRecords), + accessPerUser: &sync.Map{}, mutex: &sync.Mutex{}, requestPerSec: requestPerSec, burstSize: burstSize, @@ -80,6 +86,7 @@ func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Du return l } +// RateLimiter is a struct that implements the RateLimiter interface from grpc middleware type RateLimiter struct { limiter *LimiterStore } diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go index b06bc5994..cafcf3d00 100644 --- a/plugins/rate_limit_test.go +++ b/plugins/rate_limit_test.go @@ -15,7 +15,7 @@ func TestNewRateLimiter(t *testing.T) { } func TestLimiterAllow(t *testing.T) { - rlStore := newRateLimitStore(1, 1, time.Second) + rlStore := newRateLimitStore(1, 1, 10*time.Second) assert.NoError(t, rlStore.Allow("hello")) assert.Error(t, rlStore.Allow("hello")) time.Sleep(time.Second) @@ -97,13 +97,30 @@ func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { func TestRateLimiterUpdateLastAccessTime(t *testing.T) { rlStore := newRateLimitStore(2, 2, time.Second) assert.NoError(t, rlStore.Allow("hello")) - firstAccessTime := rlStore.accessPerUser["hello"].lastAccess + // get last access time + + accessRecord, _ := rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + firstAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + assert.NoError(t, rlStore.Allow("hello")) - secondAccessTime := rlStore.accessPerUser["hello"].lastAccess + + accessRecord, _ = rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + secondAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + assert.True(t, secondAccessTime.After(firstAccessTime)) + // Verify that the last access time is updated even when user is rate limited assert.Error(t, rlStore.Allow("hello")) - thirdAccessTime := rlStore.accessPerUser["hello"].lastAccess + + accessRecord, _ = rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + thirdAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + assert.True(t, thirdAccessTime.After(secondAccessTime)) }