Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rate limiting #11

Merged
merged 2 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions handler/oauth2/flow_authorize_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func (c AuthorizeCodeHandler) Code(ctx context.Context, requester fosite.AccessR
}

func (c AuthorizeCodeHandler) ValidateCode(ctx context.Context, requester fosite.Requester, code string) error {
return nil
}

func (c AuthorizeCodeHandler) ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error {
return c.AuthorizeCodeStrategy.ValidateAuthorizeCode(ctx, requester, code)
}

Expand Down
22 changes: 0 additions & 22 deletions handler/oauth2/flow_authorize_code_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,6 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) {
},
expectErr: fosite.ErrServerError,
},
{
description: "should fail because authorization code is expired",
areq: &fosite.AccessRequest{
GrantTypes: fosite.Arguments{"authorization_code"},
Request: fosite.Request{
Form: url.Values{"code": []string{"foo.bar"}},
Client: &fosite.DefaultClient{
GrantTypes: fosite.Arguments{"authorization_code"},
},
Session: &fosite.DefaultSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.AuthorizeCode: time.Now().Add(-time.Hour).UTC(),
},
},
RequestedAt: time.Now().Add(-2 * time.Hour).UTC(),
},
},
setup: func(t *testing.T, areq *fosite.AccessRequest, config *fosite.Config) {
require.NoError(t, store.CreateAuthorizeCodeSession(context.Background(), "bar", areq))
},
expectErr: fosite.ErrInvalidRequest,
},
{
description: "should pass with offline scope and refresh token",
areq: &fosite.AccessRequest{
Expand Down
15 changes: 11 additions & 4 deletions handler/oauth2/flow_generic_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ type CodeHandler interface {
// Code fetches the code and code signature.
Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error)

// ValidateCode validates the code.
// ValidateCode validates the code. Can be used for checks that need to run before we fetch the session from the database.
ValidateCode(ctx context.Context, requester fosite.Requester, code string) error

// ValidateCodeSession validates the code session.
ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error
}

// SessionHandler handles session-related operations.
Expand Down Expand Up @@ -83,8 +86,8 @@ func (c *GenericCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx cont
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

if err = c.ValidateCode(ctx, requester, code); err != nil {
return errorsx.WithStack(fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error()))
if err = c.ValidateCodeSession(ctx, ar, code); err != nil {
return errorsx.WithStack(err)
}

for _, scope := range ar.GetRequestedScopes() {
Expand Down Expand Up @@ -166,6 +169,10 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context
return err
}

if err = c.ValidateCode(ctx, requester, code); err != nil {
return errorsx.WithStack(err)
}

var ar fosite.Requester
if ar, err = c.Session(ctx, requester, signature); err != nil {
if ar != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) {
Expand All @@ -175,7 +182,7 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context
return err
}

if err = c.ValidateCode(ctx, ar, code); err != nil {
if err = c.ValidateCodeSession(ctx, ar, code); err != nil {
return errorsx.WithStack(err)
}

Expand Down
26 changes: 14 additions & 12 deletions handler/rfc8628/strategy_hmacsha.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
enigma "github.com/ory/fosite/token/hmac"
)

const POLLING_RATE_LIMITING_LEEWAY = 200 * time.Millisecond

// DeviceFlowSession is a fosite.Session container specific for the device flow.
type DeviceFlowSession interface {
// GetBrowserFlowCompleted returns the flag indicating whether user has completed the browser flow or not.
Expand Down Expand Up @@ -179,13 +181,13 @@ func (h *DefaultDeviceStrategy) ShouldRateLimit(context context.Context, code st
if err != nil {
timer := new(expirationTimer)
timer.Counter = 1
timer.NotUntil = h.getExpirationTime(context, 1)
timer.NotUntil = h.getNotUntil(context, 1)
exp, err := h.serializeExpiration(timer)
if err != nil {
return false, err
return false, errorsx.WithStack(fosite.ErrServerError.WithHintf("Failed to serialize expiration struct %s", err))
}
// Set the expiration time as value, and use the lifespan of the device code as TTL.
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context)))
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context).Seconds()))
return false, nil
}

Expand All @@ -195,31 +197,31 @@ func (h *DefaultDeviceStrategy) ShouldRateLimit(context context.Context, code st
}

// The code is valid and enough time has passed since the last call.
if expiration.NotUntil.Before(time.Now()) {
expiration.NotUntil = h.getExpirationTime(context, expiration.Counter)
if time.Now().After(expiration.NotUntil) {
expiration.NotUntil = h.getNotUntil(context, expiration.Counter)
exp, err := h.serializeExpiration(expiration)
if err != nil {
return false, err
return false, errorsx.WithStack(fosite.ErrServerError.WithHintf("Failed to serialize expiration struct %s", err))
}
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context)))
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context).Seconds()))
return false, nil
}

// The token calls were made too fast, we need to double the interval period
expiration.NotUntil = h.getExpirationTime(context, expiration.Counter+1)
expiration.NotUntil = h.getNotUntil(context, expiration.Counter+1)
expiration.Counter += 1
exp, err := h.serializeExpiration(expiration)
if err != nil {
return false, err
return false, errorsx.WithStack(fosite.ErrServerError.WithHintf("Failed to serialize expiration struct %s", err))
}
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context)))
h.RateLimiterCache.Set(keyBytes, exp, int(h.Config.GetDeviceAndUserCodeLifespan(context).Seconds()))

return true, nil
}

func (h *DefaultDeviceStrategy) getExpirationTime(context context.Context, multiplier int) time.Time {
func (h *DefaultDeviceStrategy) getNotUntil(context context.Context, multiplier int) time.Time {
duration := h.Config.GetDeviceAuthTokenPollingInterval(context)
expiration := time.Now().Add(duration * time.Duration(multiplier))
expiration := time.Now().Add(duration * time.Duration(multiplier)).Add(-POLLING_RATE_LIMITING_LEEWAY)
return expiration
}

Expand Down
21 changes: 12 additions & 9 deletions handler/rfc8628/token_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ type DeviceCodeHandler struct {
func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) {
code = requester.GetRequestForm().Get("device_code")

shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code)
// TODO(nsklikas) : should we error out or just silently log it?
if err != nil {
return "", "", err
}
if shouldRateLimit {
return "", "", errorsx.WithStack(fosite.ErrPollingRateLimited)
}

signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code)
if err != nil {
return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
Expand All @@ -40,6 +31,18 @@ func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequ
}

func (c DeviceCodeHandler) ValidateCode(ctx context.Context, requester fosite.Requester, code string) error {
shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code)
// TODO(nsklikas) : should we error out or just silently log it?
BarcoMasile marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}
if shouldRateLimit {
return errorsx.WithStack(fosite.ErrPollingRateLimited)
}
return nil
}

func (c DeviceCodeHandler) ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error {
return c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code)
}

Expand Down
28 changes: 0 additions & 28 deletions handler/rfc8628/token_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,33 +376,6 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
},
expectErr: fosite.ErrServerError,
},
{
description: "should fail because device code is expired",
areq: &fosite.AccessRequest{
GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)},
Request: fosite.Request{
Form: url.Values{},
Client: &fosite.DefaultClient{
GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)},
},
Session: &DefaultDeviceFlowSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.DeviceCode: time.Now().Add(-time.Hour).UTC(),
},
BrowserFlowCompleted: true,
},
RequestedAt: time.Now().Add(-2 * time.Hour).UTC(),
},
},
setup: func(t *testing.T, areq *fosite.AccessRequest, config *fosite.Config) {
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.Background(), signature, areq))
},
expectErr: fosite.ErrInvalidRequest,
},
{
description: "should pass with offline scope and refresh token",
areq: &fosite.AccessRequest{
Expand Down Expand Up @@ -767,7 +740,6 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) {
mockCoreStore = internal.NewMockCoreStorage(ctrl)
mockDeviceCodeStore = internal.NewMockDeviceCodeStorage(ctrl)
mockDeviceRateLimitStrategy = internal.NewMockDeviceRateLimitStrategy(ctrl)
mockDeviceRateLimitStrategy.EXPECT().ShouldRateLimit(gomock.Any(), gomock.Any()).Return(false, nil).Times(1)
testCase.setup()

handler := oauth2.GenericCodeTokenEndpointHandler{
Expand Down
Loading