diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 734e2ab0..1a36b605 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -24,6 +24,10 @@ func (c AuthorizeCodeHandler) Code(ctx context.Context, requester fosite.AccessR } func (c AuthorizeCodeHandler) ValidateCode(ctx context.Context, requester fosite.Requester, code string) error { + return nil +} + +func (c AuthorizeCodeHandler) ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error { return c.AuthorizeCodeStrategy.ValidateAuthorizeCode(ctx, requester, code) } diff --git a/handler/oauth2/flow_authorize_code_token_test.go b/handler/oauth2/flow_authorize_code_token_test.go index aac54dfe..e7349627 100644 --- a/handler/oauth2/flow_authorize_code_token_test.go +++ b/handler/oauth2/flow_authorize_code_token_test.go @@ -64,28 +64,6 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) { }, expectErr: fosite.ErrServerError, }, - { - description: "should fail because authorization code is expired", - areq: &fosite.AccessRequest{ - GrantTypes: fosite.Arguments{"authorization_code"}, - Request: fosite.Request{ - Form: url.Values{"code": []string{"foo.bar"}}, - Client: &fosite.DefaultClient{ - GrantTypes: fosite.Arguments{"authorization_code"}, - }, - Session: &fosite.DefaultSession{ - ExpiresAt: map[fosite.TokenType]time.Time{ - fosite.AuthorizeCode: time.Now().Add(-time.Hour).UTC(), - }, - }, - RequestedAt: time.Now().Add(-2 * time.Hour).UTC(), - }, - }, - setup: func(t *testing.T, areq *fosite.AccessRequest, config *fosite.Config) { - require.NoError(t, store.CreateAuthorizeCodeSession(context.Background(), "bar", areq)) - }, - expectErr: fosite.ErrInvalidRequest, - }, { description: "should pass with offline scope and refresh token", areq: &fosite.AccessRequest{ diff --git a/handler/oauth2/flow_generic_code_token.go b/handler/oauth2/flow_generic_code_token.go index 756507c2..e469e558 100644 --- a/handler/oauth2/flow_generic_code_token.go +++ b/handler/oauth2/flow_generic_code_token.go @@ -35,8 +35,11 @@ type CodeHandler interface { // Code fetches the code and code signature. Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) - // ValidateCode validates the code. + // ValidateCode validates the code. Can be used for checks that need to run before we fetch the session from the database. ValidateCode(ctx context.Context, requester fosite.Requester, code string) error + + // ValidateCodeSession validates the code session. + ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error } // SessionHandler handles session-related operations. @@ -83,8 +86,8 @@ func (c *GenericCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx cont return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - if err = c.ValidateCode(ctx, requester, code); err != nil { - return errorsx.WithStack(fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error())) + if err = c.ValidateCodeSession(ctx, ar, code); err != nil { + return errorsx.WithStack(err) } for _, scope := range ar.GetRequestedScopes() { @@ -166,6 +169,10 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context return err } + if err = c.ValidateCode(ctx, requester, code); err != nil { + return errorsx.WithStack(err) + } + var ar fosite.Requester if ar, err = c.Session(ctx, requester, signature); err != nil { if ar != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) { @@ -175,7 +182,7 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context return err } - if err = c.ValidateCode(ctx, ar, code); err != nil { + if err = c.ValidateCodeSession(ctx, ar, code); err != nil { return errorsx.WithStack(err) } diff --git a/handler/rfc8628/token_handler.go b/handler/rfc8628/token_handler.go index a8d0f8bb..f3df5668 100644 --- a/handler/rfc8628/token_handler.go +++ b/handler/rfc8628/token_handler.go @@ -22,15 +22,6 @@ type DeviceCodeHandler struct { func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) { code = requester.GetRequestForm().Get("device_code") - shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code) - // TODO(nsklikas) : should we error out or just silently log it? - if err != nil { - return "", "", err - } - if shouldRateLimit { - return "", "", errorsx.WithStack(fosite.ErrPollingRateLimited) - } - signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) if err != nil { return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) @@ -40,6 +31,18 @@ func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequ } func (c DeviceCodeHandler) ValidateCode(ctx context.Context, requester fosite.Requester, code string) error { + shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code) + // TODO(nsklikas) : should we error out or just silently log it? + if err != nil { + return err + } + if shouldRateLimit { + return errorsx.WithStack(fosite.ErrPollingRateLimited) + } + return nil +} + +func (c DeviceCodeHandler) ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error { return c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code) } diff --git a/handler/rfc8628/token_handler_test.go b/handler/rfc8628/token_handler_test.go index 01d55d8b..6e29d187 100644 --- a/handler/rfc8628/token_handler_test.go +++ b/handler/rfc8628/token_handler_test.go @@ -376,33 +376,6 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { }, expectErr: fosite.ErrServerError, }, - { - description: "should fail because device code is expired", - areq: &fosite.AccessRequest{ - GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)}, - Request: fosite.Request{ - Form: url.Values{}, - Client: &fosite.DefaultClient{ - GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)}, - }, - Session: &DefaultDeviceFlowSession{ - ExpiresAt: map[fosite.TokenType]time.Time{ - fosite.DeviceCode: time.Now().Add(-time.Hour).UTC(), - }, - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().Add(-2 * time.Hour).UTC(), - }, - }, - setup: func(t *testing.T, areq *fosite.AccessRequest, config *fosite.Config) { - code, signature, err := strategy.GenerateDeviceCode(context.TODO()) - require.NoError(t, err) - areq.Form.Add("device_code", code) - - require.NoError(t, store.CreateDeviceCodeSession(context.Background(), signature, areq)) - }, - expectErr: fosite.ErrInvalidRequest, - }, { description: "should pass with offline scope and refresh token", areq: &fosite.AccessRequest{ @@ -767,7 +740,6 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { mockCoreStore = internal.NewMockCoreStorage(ctrl) mockDeviceCodeStore = internal.NewMockDeviceCodeStorage(ctrl) mockDeviceRateLimitStrategy = internal.NewMockDeviceRateLimitStrategy(ctrl) - mockDeviceRateLimitStrategy.EXPECT().ShouldRateLimit(gomock.Any(), gomock.Any()).Return(false, nil).Times(1) testCase.setup() handler := oauth2.GenericCodeTokenEndpointHandler{