Skip to content

Commit

Permalink
fix: implement rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Mar 28, 2024
1 parent 3c931a9 commit 9994a15
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 48 deletions.
11 changes: 4 additions & 7 deletions compose/compose_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ package compose
import (
"context"

"github.com/coocood/freecache"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/handler/rfc8628"
"github.com/ory/fosite/token/hmac"
"github.com/ory/fosite/token/jwt"
"github.com/patrickmn/go-cache"
)

type CommonStrategy struct {
Expand Down Expand Up @@ -58,11 +58,8 @@ func NewOpenIDConnectStrategy(keyGetter func(context.Context) (interface{}, erro
// Create a new device strategy
func NewDeviceStrategy(config fosite.Configurator) *rfc8628.DefaultDeviceStrategy {
return &rfc8628.DefaultDeviceStrategy{
Enigma: &hmac.HMACStrategy{Config: config},
RateLimiterCache: cache.New(
config.GetDeviceAndUserCodeLifespan(context.TODO()),
config.GetDeviceAndUserCodeLifespan(context.TODO())*2,
),
Config: config,
Enigma: &hmac.HMACStrategy{Config: config},
RateLimiterCache: freecache.NewCache(1024 * 1024),
Config: config,
}
}
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ const (
errRegistrationNotSupportedName = "registration_not_supported"
errJTIKnownName = "jti_known"
errAuthorizationPending = "authorization_pending"
errPollingIntervalRateLimited = "polling_interval_rate_limited"
errPollingIntervalRateLimited = "slow_down"
errDeviceExpiredToken = "expired_token"
)

Expand Down
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ replace github.com/gorilla/sessions => github.com/gorilla/sessions v1.2.1

require (
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2
github.com/coocood/freecache v1.2.4
github.com/cristalhq/jwt/v4 v4.0.2
github.com/dgraph-io/ristretto v0.1.1
github.com/ecordell/optgen v0.0.9
Expand All @@ -25,7 +26,6 @@ require (
github.com/ory/go-convenience v0.1.0
github.com/ory/x v0.0.613
github.com/parnurzeal/gorequest v0.2.15
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.14.3
Expand All @@ -34,7 +34,6 @@ require (
golang.org/x/net v0.20.0
golang.org/x/oauth2 v0.15.0
golang.org/x/text v0.14.0
golang.org/x/time v0.4.0
)

require (
Expand Down
7 changes: 3 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqy
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
Expand All @@ -56,6 +57,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coocood/freecache v1.2.4 h1:UdR6Yz/X1HW4fZOuH0Z94KwG851GWOSknua5VUbb/5M=
github.com/coocood/freecache v1.2.4/go.mod h1:RBUWa/Cy+OHdfTGFEhEuE1pMCMX51Ncizj7rthiQ3vk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
Expand Down Expand Up @@ -344,8 +347,6 @@ github.com/ory/x v0.0.613 h1:MHT0scH7hcrOkc3aH7qqYLzXVJkjhB0szWTwpD2lh8Q=
github.com/ory/x v0.0.613/go.mod h1:uH065puz8neija0neqwIN3PmXXfDsB9VbZTZ20Znoos=
github.com/parnurzeal/gorequest v0.2.15 h1:oPjDCsF5IkD4gUk6vIgsxYNaSgvAnIh1EJeROn3HdJU=
github.com/parnurzeal/gorequest v0.2.15/go.mod h1:3Kh2QUMJoqw3icWAecsyzkpY7UzRfDhbRdTjtNwNiUE=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
Expand Down Expand Up @@ -675,8 +676,6 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.4.0 h1:Z81tqI5ddIoXDPvVQ7/7CC9TnLM7ubaFG2qXYd5BbYY=
golang.org/x/time v0.4.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
Expand Down
2 changes: 1 addition & 1 deletion handler/rfc8628/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type RFC8628CodeStrategy interface {

// DeviceRateLimitStrategy handles the rate limiting strategy
type DeviceRateLimitStrategy interface {
ShouldRateLimit(ctx context.Context, code string) bool
ShouldRateLimit(ctx context.Context, code string) (bool, error)
}

// DeviceCodeStrategy handles the device_code strategy
Expand Down
79 changes: 65 additions & 14 deletions handler/rfc8628/strategy_hmacsha.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ package rfc8628

import (
"context"
"encoding/json"
"strings"
"time"

"github.com/coocood/freecache"
"github.com/mohae/deepcopy"

"github.com/ory/x/errorsx"

"github.com/ory/x/randx"
"github.com/patrickmn/go-cache"
"golang.org/x/time/rate"

"github.com/ory/fosite"
enigma "github.com/ory/fosite/token/hmac"
Expand Down Expand Up @@ -100,7 +100,7 @@ func (s *DefaultDeviceFlowSession) SetBrowserFlowCompleted(flag bool) {
// DefaultDeviceStrategy implements the default device strategy
type DefaultDeviceStrategy struct {
Enigma *enigma.HMACStrategy
RateLimiterCache *cache.Cache
RateLimiterCache *freecache.Cache
Config interface {
fosite.DeviceProvider
fosite.DeviceAndUserCodeLifespanProvider
Expand Down Expand Up @@ -170,20 +170,71 @@ func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite
}

// ShouldRateLimit is used to decide whether a request should be rate-limited
func (h *DefaultDeviceStrategy) ShouldRateLimit(context context.Context, code string) bool {
func (h *DefaultDeviceStrategy) ShouldRateLimit(context context.Context, code string) (bool, error) {
key := code + "_limiter"

if x, found := h.RateLimiterCache.Get(key); found {
return !x.(*rate.Limiter).Allow()
keyBytes := []byte(key)
object, err := h.RateLimiterCache.Get(keyBytes)
// This code is not in the cache, so we just add it
if err != nil {
timer := new(expirationTimer)
timer.Counter = 1
timer.NotUntil = h.getExpirationTime(context, 1)
exp, err := h.serializeExpiration(timer)
if err != nil {
return false, err
}
// Set the expiration time as value, and use the lifespan of the device code as TTL.
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context)))
return false, nil
}

expiration, err := h.deserializeExpiration(object)
if err != nil {
return false, errorsx.WithStack(fosite.ErrServerError.WithHintf("Failed to store to rate limit cache: %s", err))
}

// The code is valid and enough time has passed since the last call.
if expiration.NotUntil.Before(time.Now()) {
expiration.NotUntil = h.getExpirationTime(context, expiration.Counter)
exp, err := h.serializeExpiration(expiration)
if err != nil {
return false, err
}
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context)))
return false, nil
}

// The token calls were made too fast, we need to double the interval period
expiration.NotUntil = h.getExpirationTime(context, expiration.Counter+1)
expiration.Counter += 1
exp, err := h.serializeExpiration(expiration)
if err != nil {
return false, err
}
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context)))

rateLimiter := rate.NewLimiter(
rate.Every(
h.Config.GetDeviceAuthTokenPollingInterval(context),
),
1,
)
return true, nil
}

func (h *DefaultDeviceStrategy) getExpirationTime(context context.Context, multiplier int) time.Time {
duration := h.Config.GetDeviceAuthTokenPollingInterval(context)
expiration := time.Now().Add(duration * time.Duration(multiplier))
return expiration
}

type expirationTimer struct {
NotUntil time.Time
Counter int
}

func (h *DefaultDeviceStrategy) serializeExpiration(exp *expirationTimer) ([]byte, error) {
b, err := json.Marshal(exp)
return b, err
}

h.RateLimiterCache.Set(key, rateLimiter, cache.DefaultExpiration)
return false
func (h *DefaultDeviceStrategy) deserializeExpiration(b []byte) (*expirationTimer, error) {
timer := new(expirationTimer)
err := json.Unmarshal(b, timer)
return timer, err
}
33 changes: 23 additions & 10 deletions handler/rfc8628/strategy_hmacsha_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"testing"
"time"

"github.com/patrickmn/go-cache"
"github.com/coocood/freecache"
"github.com/stretchr/testify/assert"

"github.com/ory/fosite"
Expand All @@ -21,7 +21,7 @@ import (

var hmacshaStrategy = DefaultDeviceStrategy{
Enigma: &hmac.HMACStrategy{Config: &fosite.Config{GlobalSecret: []byte("foobarfoobarfoobarfoobarfoobarfoobarfoobarfoobar")}},
RateLimiterCache: cache.New(24*time.Minute, 2*24*time.Minute),
RateLimiterCache: freecache.NewCache(16384 * 64),
Config: &fosite.Config{
AccessTokenLifespan: time.Minute * 24,
AuthorizeCodeLifespan: time.Minute * 24,
Expand Down Expand Up @@ -115,17 +115,30 @@ func TestHMACDeviceCode(t *testing.T) {

func TestRateLimit(t *testing.T) {
t.Run("ratelimit no-wait", func(t *testing.T) {
hmacshaStrategy.RateLimiterCache.Flush()
assert.False(t, hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA"))
assert.False(t, hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA"))
assert.True(t, hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA"))
hmacshaStrategy.RateLimiterCache.Clear()
b, err := hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA")
assert.NoError(t, err)
assert.False(t, b)
b, err = hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA")
assert.NoError(t, err)
assert.True(t, b)
})

t.Run("ratelimit wait", func(t *testing.T) {
hmacshaStrategy.RateLimiterCache.Flush()
assert.False(t, hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA"))
assert.False(t, hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA"))
hmacshaStrategy.RateLimiterCache.Clear()
b, err := hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA")
assert.NoError(t, err)
assert.False(t, b)
time.Sleep(500 * time.Millisecond)
assert.False(t, hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA"))
b, err = hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA")
assert.NoError(t, err)
assert.False(t, b)
time.Sleep(500 * time.Millisecond)
b, err = hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA")
assert.NoError(t, err)
assert.False(t, b)
b, err = hmacshaStrategy.ShouldRateLimit(context.TODO(), "AAA")
assert.NoError(t, err)
assert.True(t, b)
})
}
7 changes: 6 additions & 1 deletion handler/rfc8628/token_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ type DeviceCodeHandler struct {
func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) {
code = requester.GetRequestForm().Get("device_code")

if c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code) {
shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code)
// TODO(nsklikas) : should we error out or just silently log it?
if err != nil {
return "", "", err
}
if shouldRateLimit {
return "", "", errorsx.WithStack(fosite.ErrPollingRateLimited)
}

Expand Down
78 changes: 70 additions & 8 deletions handler/rfc8628/token_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ import (
"testing"
"time"

"github.com/coocood/freecache"
"github.com/pkg/errors"

"github.com/golang/mock/gomock"
"github.com/ory/fosite/internal"

"github.com/patrickmn/go-cache"

"github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/token/hmac"

Expand All @@ -35,11 +34,8 @@ var hmacshaStrategy = oauth2.HMACSHAStrategy{
}

var RFC8628HMACSHAStrategy = DefaultDeviceStrategy{
Enigma: &hmac.HMACStrategy{Config: &fosite.Config{GlobalSecret: []byte("foobarfoobarfoobarfoobarfoobarfoobarfoobarfoobar")}},
RateLimiterCache: cache.New(
time.Hour*12,
time.Hour*24,
),
Enigma: &hmac.HMACStrategy{Config: &fosite.Config{GlobalSecret: []byte("foobarfoobarfoobarfoobarfoobarfoobarfoobarfoobar")}},
RateLimiterCache: freecache.NewCache(16384 * 64),
Config: &fosite.Config{
DeviceAndUserCodeLifespan: time.Hour * 24,
},
Expand Down Expand Up @@ -268,6 +264,72 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
}
}

func TestDeviceUserCode_HandleTokenEndpointRequest_Ratelimitting(t *testing.T) {
for k, strategy := range map[string]struct {
oauth2.CoreStrategy
RFC8628CodeStrategy
}{
"hmac": {&hmacshaStrategy, &RFC8628HMACSHAStrategy},
} {
t.Run("strategy="+k, func(t *testing.T) {
store := storage.NewMemoryStore()

h := oauth2.GenericCodeTokenEndpointHandler{
AccessRequestValidator: &DeviceAccessRequestValidator{},
CodeHandler: &DeviceCodeHandler{
DeviceRateLimitStrategy: strategy,
DeviceCodeStrategy: strategy,
},
SessionHandler: &DeviceSessionHandler{
DeviceCodeStorage: store,
},
CoreStorage: store,
AccessTokenStrategy: strategy.CoreStrategy,
RefreshTokenStrategy: strategy.CoreStrategy,
Config: &fosite.Config{
ScopeStrategy: fosite.HierarchicScopeStrategy,
AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy,
DeviceAndUserCodeLifespan: time.Minute,
},
}
areq := &fosite.AccessRequest{
GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)},
Request: fosite.Request{
Form: url.Values{},
Client: &fosite.DefaultClient{
ID: "foo",
GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)},
},
GrantedScope: fosite.Arguments{"foo", "offline"},
Session: &DefaultDeviceFlowSession{},
RequestedAt: time.Now().UTC(),
},
}
authreq := &fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{string(fosite.GrantTypeDeviceCode)}},
Session: &DefaultDeviceFlowSession{
BrowserFlowCompleted: true,
},
RequestedAt: time.Now().UTC(),
},
}

token, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)

areq.Form = url.Values{"device_code": {token}}
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
err = h.HandleTokenEndpointRequest(context.Background(), areq)
require.NoError(t, err, "%+v", err)
err = h.HandleTokenEndpointRequest(context.Background(), areq)
require.Error(t, fosite.ErrPollingRateLimited, err)
time.Sleep(10 * time.Second)
err = h.HandleTokenEndpointRequest(context.Background(), areq)
require.NoError(t, err, "%+v", err)
})
}
}
func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
for k, strategy := range map[string]struct {
oauth2.CoreStrategy
Expand Down Expand Up @@ -705,7 +767,7 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) {
mockCoreStore = internal.NewMockCoreStorage(ctrl)
mockDeviceCodeStore = internal.NewMockDeviceCodeStorage(ctrl)
mockDeviceRateLimitStrategy = internal.NewMockDeviceRateLimitStrategy(ctrl)
mockDeviceRateLimitStrategy.EXPECT().ShouldRateLimit(gomock.Any(), gomock.Any()).Return(false).Times(1)
mockDeviceRateLimitStrategy.EXPECT().ShouldRateLimit(gomock.Any(), gomock.Any()).Return(false, nil).Times(1)
testCase.setup()

handler := oauth2.GenericCodeTokenEndpointHandler{
Expand Down

0 comments on commit 9994a15

Please sign in to comment.