diff --git a/compose/compose_oauth2.go b/compose/compose_oauth2.go index d405eb68..9be25fb4 100644 --- a/compose/compose_oauth2.go +++ b/compose/compose_oauth2.go @@ -24,9 +24,12 @@ func OAuth2AuthorizeExplicitAuthFactory(config fosite.Configurator, storage inte func Oauth2AuthorizeExplicitTokenFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { return &oauth2.AuthorizeExplicitTokenEndpointHandler{ GenericCodeTokenEndpointHandler: oauth2.GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &oauth2.AuthorizeExplicitGrantTokenHandler{ + AccessRequestValidator: &oauth2.AuthorizeExplicitGrantAccessRequestValidator{}, + CodeHandler: &oauth2.AuthorizeCodeHandler{ AuthorizeCodeStrategy: strategy.(oauth2.AuthorizeCodeStrategy), - AuthorizeCodeStorage: storage.(oauth2.AuthorizeCodeStorage), + }, + SessionHandler: &oauth2.AuthorizeExplicitGrantSessionHandler{ + AuthorizeCodeStorage: storage.(oauth2.AuthorizeCodeStorage), }, AccessTokenStrategy: strategy.(oauth2.AccessTokenStrategy), diff --git a/compose/compose_rfc8628.go b/compose/compose_rfc8628.go index 62e212f0..64ccc1e4 100644 --- a/compose/compose_rfc8628.go +++ b/compose/compose_rfc8628.go @@ -26,10 +26,13 @@ func RFC8628DeviceFactory(config fosite.Configurator, storage interface{}, strat func RFC8628DeviceAuthorizationTokenFactory(config fosite.Configurator, storage interface{}, strategy interface{}) interface{} { return &rfc8628.DeviceCodeTokenEndpointHandler{ GenericCodeTokenEndpointHandler: oauth2.GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &rfc8628.DeviceTokenHandler{ + AccessRequestValidator: &rfc8628.DeviceAccessRequestValidator{}, + CodeHandler: &rfc8628.DeviceCodeHandler{ DeviceRateLimitStrategy: strategy.(rfc8628.DeviceRateLimitStrategy), DeviceCodeStrategy: strategy.(rfc8628.DeviceCodeStrategy), - DeviceCodeStorage: storage.(rfc8628.DeviceCodeStorage), + }, + SessionHandler: &rfc8628.DeviceSessionHandler{ + DeviceCodeStorage: storage.(rfc8628.DeviceCodeStorage), }, AccessTokenStrategy: strategy.(oauth2.AccessTokenStrategy), diff --git a/fosite_test.go b/fosite_test.go index 61ebe651..ded80b2c 100644 --- a/fosite_test.go +++ b/fosite_test.go @@ -26,7 +26,6 @@ func TestAuthorizeEndpointHandlers(t *testing.T) { } func TestTokenEndpointHandlers(t *testing.T) { - // h := &oauth2.AuthorizeExplicitGrantHandler{} h := &oauth2.GenericCodeTokenEndpointHandler{} hs := TokenEndpointHandlers{} hs.Append(h) diff --git a/handler.go b/handler.go index 962c2f6a..163f4af6 100644 --- a/handler.go +++ b/handler.go @@ -8,7 +8,7 @@ import ( ) type AuthorizeEndpointHandler interface { - // HandleAuthorizeEndpointRequest handles an authorize endpoint request. To extend the handler's capabilities, the http request + // HandleAuthorizeRequest handles an authorize endpoint request. To extend the handler's capabilities, the http request // is passed along, if further information retrieval is required. If the handler feels that he is not responsible for // the authorize request, he must return nil and NOT modify session nor responder neither requester. // @@ -31,12 +31,12 @@ type TokenEndpointHandler interface { // the request, this method should return ErrUnknownRequest and otherwise handle the request. HandleTokenEndpointRequest(ctx context.Context, requester AccessRequester) error - // CanSkipClientAuth indicates if client authentication can be skipped. By default, it MUST be false, unless you are + // CanSkipClientAuth indicates if client authentication can be skipped. By default it MUST be false, unless you are // implementing extension grant type, which allows unauthenticated client. CanSkipClientAuth must be called // before HandleTokenEndpointRequest to decide, if AccessRequester will contain authenticated client. CanSkipClientAuth(ctx context.Context, requester AccessRequester) bool - // CanHandleTokenEndpointRequest indicates, if TokenEndpointHandler can handle this request or not. If true, + // CanHandleRequest indicates, if TokenEndpointHandler can handle this request or not. If true, // HandleTokenEndpointRequest can be called. CanHandleTokenEndpointRequest(ctx context.Context, requester AccessRequester) bool } @@ -61,7 +61,7 @@ type RevocationHandler interface { // PushedAuthorizeEndpointHandler is the interface that handles PAR (https://datatracker.ietf.org/doc/html/rfc9126) type PushedAuthorizeEndpointHandler interface { - // HandlePushedAuthorizeEndpointRequest handles a pushed authorize endpoint request. To extend the handler's capabilities, the http request + // HandlePushedAuthorizeRequest handles a pushed authorize endpoint request. To extend the handler's capabilities, the http request // is passed along, if further information retrieval is required. If the handler feels that he is not responsible for // the pushed authorize request, he must return nil and NOT modify session nor responder neither requester. HandlePushedAuthorizeEndpointRequest(ctx context.Context, requester AuthorizeRequester, responder PushedAuthorizeResponder) error diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 29085278..69cb0e76 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -6,54 +6,94 @@ package oauth2 import ( "context" + "github.com/pkg/errors" + "github.com/ory/x/errorsx" "github.com/ory/fosite" ) -// AuthorizeExplicitGrantTokenHandler is a response handler for the Authorize Code grant using the explicit grant type -// as defined in https://tools.ietf.org/html/rfc6749#section-4.1 -type AuthorizeExplicitGrantTokenHandler struct { +type AuthorizeCodeHandler struct { AuthorizeCodeStrategy AuthorizeCodeStrategy - AuthorizeCodeStorage AuthorizeCodeStorage } -func (c AuthorizeExplicitGrantTokenHandler) ValidateGrantTypes(ctx context.Context, requester fosite.AccessRequester) error { - if !requester.GetClient().GetGrantTypes().Has("authorization_code") { - return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("The OAuth 2.0 Client is not allowed to use authorization grant \"authorization_code\".")) - } - - return nil +func (c AuthorizeCodeHandler) Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) { + code = requester.GetRequestForm().Get("code") + signature = c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code) + return code, signature, nil } -func (c AuthorizeExplicitGrantTokenHandler) ValidateCode(ctx context.Context, requester fosite.AccessRequester, code string) error { +func (c AuthorizeCodeHandler) ValidateCode(ctx context.Context, requester fosite.AccessRequester, code string) error { return c.AuthorizeCodeStrategy.ValidateAuthorizeCode(ctx, requester, code) } -func (c AuthorizeExplicitGrantTokenHandler) GetCodeAndSession(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, authorizeRequest fosite.Requester, err error) { - code = requester.GetRequestForm().Get("code") - signature = c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code) - req, err := c.AuthorizeCodeStorage.GetAuthorizeCodeSession(ctx, signature, requester.GetSession()) - return code, signature, req, err +type AuthorizeExplicitGrantSessionHandler struct { + AuthorizeCodeStorage AuthorizeCodeStorage } -func (c AuthorizeExplicitGrantTokenHandler) InvalidateSession(ctx context.Context, signature string) error { - return c.AuthorizeCodeStorage.InvalidateAuthorizeCodeSession(ctx, signature) +func (s AuthorizeExplicitGrantSessionHandler) Session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.Requester, error) { + req, err := s.AuthorizeCodeStorage.GetAuthorizeCodeSession(ctx, codeSignature, requester.GetSession()) + + if err != nil && errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) { + if req == nil { + return req, fosite.ErrServerError. + WithHint("Misconfigured code lead to an error that prohibited the OAuth 2.0 Framework from processing this request."). + WithDebug("\"GetAuthorizeCodeSession\" must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedAuthorizeCode\".") + } + + return req, err + } + + if err != nil && errors.Is(err, fosite.ErrNotFound) { + return nil, errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error())) + } + + if err != nil { + return nil, errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + return req, err } -func (c AuthorizeExplicitGrantTokenHandler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { - return false +func (s AuthorizeExplicitGrantSessionHandler) InvalidateSession(ctx context.Context, codeSignature string) error { + return s.AuthorizeCodeStorage.InvalidateAuthorizeCodeSession(ctx, codeSignature) } -func (c AuthorizeExplicitGrantTokenHandler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { +type AuthorizeExplicitGrantAccessRequestValidator struct{} + +func (v AuthorizeExplicitGrantAccessRequestValidator) ValidateRequest(requester fosite.AccessRequester) bool { return requester.GetGrantTypes().ExactOne("authorization_code") } +func (v AuthorizeExplicitGrantAccessRequestValidator) ValidateClientAuth(requester fosite.AccessRequester) bool { + return false +} + +func (v AuthorizeExplicitGrantAccessRequestValidator) ValidateGrantTypes(requester fosite.AccessRequester) error { + if !requester.GetClient().GetGrantTypes().Has("authorization_code") { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("The OAuth 2.0 Client is not allowed to use authorization grant \"authorization_code\".")) + } + + return nil +} + +func (v AuthorizeExplicitGrantAccessRequestValidator) ValidateRedirectURI(accessRequester fosite.AccessRequester, authorizeRequester fosite.Requester) error { + forcedRedirectURI := authorizeRequester.GetRequestForm().Get("redirect_uri") + requestedRedirectURI := accessRequester.GetRequestForm().Get("redirect_uri") + if forcedRedirectURI != "" && forcedRedirectURI != requestedRedirectURI { + return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("The \"redirect_uri\" from this request does not match the one from the authorize request.")) + } + + return nil +} + type AuthorizeExplicitTokenEndpointHandler struct { GenericCodeTokenEndpointHandler } var ( - _ CodeTokenEndpointHandler = (*AuthorizeExplicitGrantTokenHandler)(nil) + _ AccessRequestValidator = (*AuthorizeExplicitGrantAccessRequestValidator)(nil) + _ CodeHandler = (*AuthorizeCodeHandler)(nil) + _ SessionHandler = (*AuthorizeExplicitGrantSessionHandler)(nil) _ fosite.TokenEndpointHandler = (*AuthorizeExplicitTokenEndpointHandler)(nil) ) diff --git a/handler/oauth2/flow_authorize_code_token_test.go b/handler/oauth2/flow_authorize_code_token_test.go index 41fa70db..59d89ab3 100644 --- a/handler/oauth2/flow_authorize_code_token_test.go +++ b/handler/oauth2/flow_authorize_code_token_test.go @@ -208,9 +208,12 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) { RefreshTokenScopes: []string{"offline"}, } h = GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &AuthorizeExplicitGrantTokenHandler{ + AccessRequestValidator: &AuthorizeExplicitGrantAccessRequestValidator{}, + CodeHandler: &AuthorizeCodeHandler{ AuthorizeCodeStrategy: strategy, - AuthorizeCodeStorage: store, + }, + SessionHandler: &AuthorizeExplicitGrantSessionHandler{ + AuthorizeCodeStorage: store, }, AccessTokenStrategy: strategy, RefreshTokenStrategy: strategy, @@ -252,9 +255,12 @@ func TestAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) { AuthorizeCodeLifespan: time.Minute, } h := GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &AuthorizeExplicitGrantTokenHandler{ + AccessRequestValidator: &AuthorizeExplicitGrantAccessRequestValidator{}, + CodeHandler: &AuthorizeCodeHandler{ AuthorizeCodeStrategy: strategy, - AuthorizeCodeStorage: store, + }, + SessionHandler: &AuthorizeExplicitGrantSessionHandler{ + AuthorizeCodeStorage: store, }, TokenRevocationStorage: store, Config: config, @@ -668,8 +674,11 @@ func TestAuthorizeCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { AuthorizeCodeLifespan: time.Minute, } handler := GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &AuthorizeExplicitGrantTokenHandler{ + AccessRequestValidator: &AuthorizeExplicitGrantAccessRequestValidator{}, + CodeHandler: &AuthorizeCodeHandler{ AuthorizeCodeStrategy: &strategy, + }, + SessionHandler: &AuthorizeExplicitGrantSessionHandler{ AuthorizeCodeStorage: authorizeTransactionalStore{ mockTransactional, mockAuthorizeStore, diff --git a/handler/oauth2/flow_generic_code_token.go b/handler/oauth2/flow_generic_code_token.go index e3df13fb..b8584e36 100644 --- a/handler/oauth2/flow_generic_code_token.go +++ b/handler/oauth2/flow_generic_code_token.go @@ -18,31 +18,46 @@ import ( "github.com/ory/fosite" ) -// CodeTokenEndpointHandler handles the differences between Authorize code grant and extended grant types. -type CodeTokenEndpointHandler interface { - // ValidateGrantTypes validates the authorization grant type. - ValidateGrantTypes(ctx context.Context, requester fosite.AccessRequester) error +// AccessRequestValidator handles various validations in the access request handling. +type AccessRequestValidator interface { + // ValidateRequest validates if the access request should be handled. + ValidateRequest(requester fosite.AccessRequester) bool - // ValidateCode validates the code used in the authorization flow. - ValidateCode(ctx context.Context, requester fosite.AccessRequester, code string) error + // ValidateClientAuth validates if the client authentication is required. + ValidateClientAuth(requester fosite.AccessRequester) bool + + // ValidateGrantTypes validates the grant types in the access request. + ValidateGrantTypes(requester fosite.AccessRequester) error - // GetCodeAndSession retrieves the code, the code signature, and the request session. - GetCodeAndSession(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, authorizeRequest fosite.Requester, err error) + // ValidateRedirectURI validates the redirect uri in the access request. + ValidateRedirectURI(accessRequester fosite.AccessRequester, authorizeRequester fosite.Requester) error +} - // InvalidateSession invalidates the code once the code is used. - InvalidateSession(ctx context.Context, signature string) error +// CodeHandler handles authorization/device code related operations. +type CodeHandler interface { + // Code fetches the code and code signature. + Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) + + // ValidateCode validates the code. + ValidateCode(ctx context.Context, requester fosite.AccessRequester, code string) error +} - // CanSkipClientAuth indicates if client authentication can be skipped. By default, it MUST be false, unless you are - // implementing extension grant type, which allows unauthenticated client. CanSkipClientAuth must be called - // before HandleTokenEndpointRequest to decide, if AccessRequester will contain authenticated client. - CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool +// SessionHandler handles session-related operations. +type SessionHandler interface { + // Session fetches the authorized request. + Session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.Requester, error) - // CanHandleTokenEndpointRequest indicates if GenericCodeTokenEndpointHandler can handle this request or not. - CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool + // InvalidateSession invalidates the code and session. + InvalidateSession(ctx context.Context, codeSignature string) error } +// GenericCodeTokenEndpointHandler is a token response handler for +// - the Authorize Code grant using the explicit grant type as defined in https://tools.ietf.org/html/rfc6749#section-4.1 +// - the Device Authorization Grant as defined in https://www.rfc-editor.org/rfc/rfc8628 type GenericCodeTokenEndpointHandler struct { - CodeTokenEndpointHandler + AccessRequestValidator + CodeHandler + SessionHandler AccessTokenStrategy AccessTokenStrategy RefreshTokenStrategy RefreshTokenStrategy @@ -60,7 +75,12 @@ func (c *GenericCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx cont return errorsx.WithStack(fosite.ErrUnknownRequest) } - code, signature, ar, err := c.GetCodeAndSession(ctx, requester) + code, signature, err := c.Code(ctx, requester) + if err != nil { + return err + } + + ar, err := c.Session(ctx, requester, signature) if err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } @@ -137,30 +157,21 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context return errorsx.WithStack(errorsx.WithStack(fosite.ErrUnknownRequest)) } - if err := c.ValidateGrantTypes(ctx, requester); err != nil { + if err := c.ValidateGrantTypes(requester); err != nil { return err } - code, _, ar, err := c.GetCodeAndSession(ctx, requester) + code, signature, err := c.Code(ctx, requester) if err != nil { - switch { - case errors.Is(err, fosite.ErrInvalidatedAuthorizeCode), errors.Is(err, fosite.ErrInvalidatedDeviceCode): - if ar == nil { - return fosite.ErrServerError. - WithHint("Misconfigured code lead to an error that prohibited the OAuth 2.0 Framework from processing this request."). - WithDebug("getCodeSession must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedAuthorizeCode\" or \"ErrInvalidatedDeviceCode\".") - } + return err + } - return c.revokeTokens(ctx, requester.GetID()) - case errors.Is(err, fosite.ErrAuthorizationPending): - return err - case errors.Is(err, fosite.ErrPollingRateLimited): - return errorsx.WithStack(err) - case errors.Is(err, fosite.ErrNotFound): - return errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error())) - default: - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } + ar, err := c.Session(ctx, requester, signature) + if ar != nil && err != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) { + return c.revokeTokens(ctx, requester.GetID()) + } + if err != nil { + return err } if err = c.ValidateCode(ctx, requester, code); err != nil { @@ -180,10 +191,8 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("The OAuth 2.0 Client ID from this request does not match the one from the authorize request.")) } - forcedRedirectURI := ar.GetRequestForm().Get("redirect_uri") - requestedRedirectURI := requester.GetRequestForm().Get("redirect_uri") - if forcedRedirectURI != "" && forcedRedirectURI != requestedRedirectURI { - return errorsx.WithStack(fosite.ErrInvalidGrant.WithHint("The \"redirect_uri\" from this request does not match the one from the authorize request.")) + if err = c.ValidateRedirectURI(requester, ar); err != nil { + return err } // Checking of POST client_id skipped, because @@ -204,11 +213,11 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context } func (c *GenericCodeTokenEndpointHandler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { - return c.CodeTokenEndpointHandler.CanSkipClientAuth(ctx, requester) + return c.ValidateClientAuth(requester) } func (c *GenericCodeTokenEndpointHandler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { - return c.CodeTokenEndpointHandler.CanHandleTokenEndpointRequest(ctx, requester) + return c.ValidateRequest(requester) } func (c *GenericCodeTokenEndpointHandler) canIssueRefreshToken(ctx context.Context, requester fosite.Requester) bool { diff --git a/handler/rfc8628/token_handler.go b/handler/rfc8628/token_handler.go index 40c27f2a..62569854 100644 --- a/handler/rfc8628/token_handler.go +++ b/handler/rfc8628/token_handler.go @@ -6,69 +6,99 @@ package rfc8628 import ( "context" + "github.com/pkg/errors" + "github.com/ory/x/errorsx" "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" ) -// DeviceTokenHandler is a token response handler for the Device Code introduced in the Device Authorize Grant -// as defined in https://www.rfc-editor.org/rfc/rfc8628 -type DeviceTokenHandler struct { +type DeviceCodeHandler struct { DeviceRateLimitStrategy DeviceRateLimitStrategy DeviceCodeStrategy DeviceCodeStrategy - DeviceCodeStorage DeviceCodeStorage } -func (c DeviceTokenHandler) ValidateGrantTypes(ctx context.Context, requester fosite.AccessRequester) error { - if !requester.GetClient().GetGrantTypes().Has(string(fosite.GrantTypeDeviceCode)) { - return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("The OAuth 2.0 Client is not allowed to use authorization grant \"urn:ietf:params:oauth:grant-type:device_code\".")) +func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) { + code = requester.GetRequestForm().Get("device_code") + + if c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code) { + return "", "", errorsx.WithStack(fosite.ErrPollingRateLimited) } - return nil + signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) + if err != nil { + return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + + return } -func (c DeviceTokenHandler) ValidateCode(ctx context.Context, requester fosite.AccessRequester, code string) error { +func (c DeviceCodeHandler) ValidateCode(ctx context.Context, requester fosite.AccessRequester, code string) error { return c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code) } -func (c DeviceTokenHandler) GetCodeAndSession(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, authorizeRequest fosite.Requester, err error) { - code = requester.GetRequestForm().Get("device_code") +type DeviceSessionHandler struct { + DeviceCodeStorage DeviceCodeStorage +} - if c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code) { - return "", "", nil, fosite.ErrPollingRateLimited +func (s DeviceSessionHandler) Session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.Requester, error) { + req, err := s.DeviceCodeStorage.GetDeviceCodeSession(ctx, codeSignature, requester.GetSession()) + + if err != nil && errors.Is(err, fosite.ErrInvalidatedDeviceCode) { + if req == nil { + return req, fosite.ErrServerError. + WithHint("Misconfigured code lead to an error that prohibited the OAuth 2.0 Framework from processing this request."). + WithDebug("\"GetDeviceCodeSession\" must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedDeviceCode\".") + } + + return req, nil } - signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code) - if err != nil { - return "", "", nil, errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + if err != nil && errors.Is(err, fosite.ErrNotFound) { + return nil, errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error())) } - req, err := c.DeviceCodeStorage.GetDeviceCodeSession(ctx, signature, requester.GetSession()) if err != nil { - return "", "", nil, err + return nil, errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - return code, signature, req, nil + return req, err } -func (c DeviceTokenHandler) InvalidateSession(ctx context.Context, signature string) error { - return c.DeviceCodeStorage.InvalidateDeviceCodeSession(ctx, signature) +func (s DeviceSessionHandler) InvalidateSession(ctx context.Context, codeSignature string) error { + return s.DeviceCodeStorage.InvalidateDeviceCodeSession(ctx, codeSignature) } -func (c DeviceTokenHandler) CanSkipClientAuth(ctx context.Context, requester fosite.AccessRequester) bool { +type DeviceAccessRequestValidator struct{} + +func (v DeviceAccessRequestValidator) ValidateRequest(requester fosite.AccessRequester) bool { return requester.GetGrantTypes().ExactOne(string(fosite.GrantTypeDeviceCode)) } -func (c DeviceTokenHandler) CanHandleTokenEndpointRequest(ctx context.Context, requester fosite.AccessRequester) bool { +func (v DeviceAccessRequestValidator) ValidateClientAuth(requester fosite.AccessRequester) bool { return requester.GetGrantTypes().ExactOne(string(fosite.GrantTypeDeviceCode)) } +func (v DeviceAccessRequestValidator) ValidateGrantTypes(requester fosite.AccessRequester) error { + if !requester.GetClient().GetGrantTypes().Has(string(fosite.GrantTypeDeviceCode)) { + return errorsx.WithStack(fosite.ErrUnauthorizedClient.WithHint("The OAuth 2.0 Client is not allowed to use authorization grant \"urn:ietf:params:oauth:grant-type:device_code\".")) + } + + return nil +} + +func (v DeviceAccessRequestValidator) ValidateRedirectURI(accessRequester fosite.AccessRequester, authorizeRequester fosite.Requester) error { + return nil +} + type DeviceCodeTokenEndpointHandler struct { oauth2.GenericCodeTokenEndpointHandler } var ( - _ oauth2.CodeTokenEndpointHandler = (*DeviceTokenHandler)(nil) - _ fosite.TokenEndpointHandler = (*DeviceCodeTokenEndpointHandler)(nil) + _ oauth2.AccessRequestValidator = (*DeviceAccessRequestValidator)(nil) + _ oauth2.CodeHandler = (*DeviceCodeHandler)(nil) + _ oauth2.SessionHandler = (*DeviceSessionHandler)(nil) + _ fosite.TokenEndpointHandler = (*DeviceCodeTokenEndpointHandler)(nil) ) diff --git a/handler/rfc8628/token_handler_test.go b/handler/rfc8628/token_handler_test.go index 3068a1b2..2773ad5f 100644 --- a/handler/rfc8628/token_handler_test.go +++ b/handler/rfc8628/token_handler_test.go @@ -217,10 +217,13 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { RefreshTokenScopes: []string{"offline"}, } h = oauth2.GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &DeviceTokenHandler{ + AccessRequestValidator: &DeviceAccessRequestValidator{}, + CodeHandler: &DeviceCodeHandler{ DeviceRateLimitStrategy: strategy, DeviceCodeStrategy: strategy, - DeviceCodeStorage: store, + }, + SessionHandler: &DeviceSessionHandler{ + DeviceCodeStorage: store, }, AccessTokenStrategy: strategy.CoreStrategy, RefreshTokenStrategy: strategy.CoreStrategy, @@ -262,10 +265,13 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { store := storage.NewMemoryStore() h := oauth2.GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &DeviceTokenHandler{ + AccessRequestValidator: &DeviceAccessRequestValidator{}, + CodeHandler: &DeviceCodeHandler{ DeviceRateLimitStrategy: strategy, - DeviceCodeStrategy: strategy.RFC8628CodeStrategy, - DeviceCodeStorage: store, + DeviceCodeStrategy: strategy, + }, + SessionHandler: &DeviceSessionHandler{ + DeviceCodeStorage: store, }, CoreStorage: store, AccessTokenStrategy: strategy.CoreStrategy, @@ -656,9 +662,12 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { testCase.setup() handler := oauth2.GenericCodeTokenEndpointHandler{ - CodeTokenEndpointHandler: &DeviceTokenHandler{ + AccessRequestValidator: &DeviceAccessRequestValidator{}, + CodeHandler: &DeviceCodeHandler{ DeviceRateLimitStrategy: mockDeviceRateLimitStrategy, DeviceCodeStrategy: &deviceStrategy, + }, + SessionHandler: &DeviceSessionHandler{ DeviceCodeStorage: deviceTransactionalStore{ mockTransactional, mockDeviceCodeStore, diff --git a/integration/authorize_device_grant_request_test.go b/integration/authorize_device_grant_request_test.go index 2deb3067..1497f3af 100644 --- a/integration/authorize_device_grant_request_test.go +++ b/integration/authorize_device_grant_request_test.go @@ -6,7 +6,6 @@ package integration_test import ( "context" "fmt" - "net/url" "testing" "github.com/ory/fosite" @@ -25,6 +24,7 @@ func TestDeviceFlow(t *testing.T) { hmacStrategy, } { runDeviceFlowTest(t, strategy) + runDeviceFlowAccessTokenTest(t, strategy) } } @@ -39,10 +39,12 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { Username: "peteru", }, } - fc := new(fosite.Config) - fc.DeviceVerificationURL = "https://example.com/" - fc.RefreshTokenLifespan = -1 - fc.GlobalSecret = []byte("some-secret-thats-random-some-secret-thats-random-") + + fc := &fosite.Config{ + DeviceVerificationURL: "https://example.com/", + RefreshTokenLifespan: -1, + GlobalSecret: []byte("some-secret-thats-random-some-secret-thats-random-"), + } f := compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey()) ts := mockServer(t, f, session) defer ts.Close() @@ -59,9 +61,8 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { description string setup func() err bool - client fosite.Client check func(t *testing.T, token *goauth.DeviceAuthResponse, err error) - params url.Values + cleanUp func() }{ { description: "should fail with invalid_grant", @@ -72,6 +73,9 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { check: func(t *testing.T, token *goauth.DeviceAuthResponse, err error) { assert.ErrorContains(t, err, "invalid_grant") }, + cleanUp: func() { + fositeStore.Clients["device-client"].(*fosite.DefaultClient).GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + }, }, { description: "should fail with invalid_scope", @@ -83,6 +87,10 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { check: func(t *testing.T, token *goauth.DeviceAuthResponse, err error) { assert.ErrorContains(t, err, "invalid_scope") }, + cleanUp: func() { + oauthClient.Scopes = []string{} + fositeStore.Clients["device-client"].(*fosite.DefaultClient).Scopes = []string{"fosite", "offline", "openid"} + }, }, { description: "should fail with invalid_client", @@ -93,6 +101,9 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { check: func(t *testing.T, token *goauth.DeviceAuthResponse, err error) { assert.ErrorContains(t, err, "invalid_client") }, + cleanUp: func() { + oauthClient.ClientID = "device-client" + }, }, { description: "should pass", @@ -101,24 +112,6 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { }, } { t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) { - // Restore client - fositeStore.Clients["device-client"] = &fosite.DefaultClient{ - ID: "device-client", - Secret: []byte(`$2a$10$IxMdI6d.LIRZPpSfEwNoeu4rY3FhDREsxFJXikcgdRRAStxUlsuEO`), // = "foobar" - GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, - Scopes: []string{"fosite", "offline", "openid"}, - Audience: []string{tokenURL}, - Public: true, - } - oauthClient = &goauth.Config{ - ClientID: "device-client", - ClientSecret: "foobar", - Endpoint: goauth.Endpoint{ - TokenURL: ts.URL + tokenRelativePath, - DeviceAuthURL: ts.URL + deviceAuthRelativePath, - }, - } - c.setup() resp, err := oauthClient.DeviceAuth(context.Background()) @@ -135,6 +128,123 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { c.check(t, resp, err) } + if c.cleanUp != nil { + c.cleanUp() + } + + t.Logf("Passed test case %d", k) + }) + } +} + +func runDeviceFlowAccessTokenTest(t *testing.T, strategy interface{}) { + session := &defaultSession{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "peter", + }, + Headers: &jwt.Headers{}, + Subject: "peter", + Username: "peteru", + }, + } + + fc := &fosite.Config{ + DeviceVerificationURL: "https://example.com/", + RefreshTokenLifespan: -1, + GlobalSecret: []byte("some-secret-thats-random-some-secret-thats-random-"), + DeviceAuthTokenPollingInterval: -1, + } + f := compose.ComposeAllEnabled(fc, fositeStore, gen.MustRSAKey()) + ts := mockServer(t, f, session) + defer ts.Close() + + oauthClient := &goauth.Config{ + ClientID: "device-client", + ClientSecret: "foobar", + Endpoint: goauth.Endpoint{ + TokenURL: ts.URL + tokenRelativePath, + DeviceAuthURL: ts.URL + deviceAuthRelativePath, + }, + } + resp, _ := oauthClient.DeviceAuth(context.Background()) + + for k, c := range []struct { + description string + setup func() + params []goauth.AuthCodeOption + err bool + check func(t *testing.T, token *goauth.Token, err error) + cleanUp func() + }{ + { + description: "should fail with invalid grant type", + setup: func() { + }, + params: []goauth.AuthCodeOption{goauth.SetAuthURLParam("grant_type", "invalid_grant_type")}, + err: true, + check: func(t *testing.T, token *goauth.Token, err error) { + assert.ErrorContains(t, err, "invalid_request") + }, + }, + { + description: "should fail with unauthorized client", + setup: func() { + fositeStore.Clients["device-client"].(*fosite.DefaultClient).GrantTypes = []string{"authorization_code"} + }, + params: []goauth.AuthCodeOption{}, + err: true, + check: func(t *testing.T, token *goauth.Token, err error) { + assert.ErrorContains(t, err, "unauthorized_client") + }, + cleanUp: func() { + fositeStore.Clients["device-client"].(*fosite.DefaultClient).GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + }, + }, + { + description: "should fail with invalid device code", + setup: func() {}, + params: []goauth.AuthCodeOption{goauth.SetAuthURLParam("device_code", "invalid_device_code")}, + err: true, + check: func(t *testing.T, token *goauth.Token, err error) { + assert.ErrorContains(t, err, "invalid_grant") + }, + }, + { + description: "should fail with invalid client id", + setup: func() { + oauthClient.ClientID = "invalid_client_id" + }, + err: true, + check: func(t *testing.T, token *goauth.Token, err error) { + assert.ErrorContains(t, err, "unauthorized_client") + }, + cleanUp: func() { + oauthClient.ClientID = "device-client" + }, + }, + { + description: "should pass", + setup: func() {}, + err: false, + }, + } { + t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) { + c.setup() + + token, err := oauthClient.DeviceAccessToken(context.Background(), resp, c.params...) + if !c.err { + assert.NotEmpty(t, token.AccessToken) + } + + if c.check != nil { + c.check(t, token, err) + } + + if c.cleanUp != nil { + c.cleanUp() + } + t.Logf("Passed test case %d", k) }) } diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index f33b1894..ab1384ec 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -83,6 +83,14 @@ var fositeStore = &storage.MemoryStore{ Scopes: []string{"fosite", "offline", "openid"}, Audience: []string{tokenURL}, }, + "device-client": &fosite.DefaultClient{ + ID: "device-client", + Secret: []byte(`$2a$10$IxMdI6d.LIRZPpSfEwNoeu4rY3FhDREsxFJXikcgdRRAStxUlsuEO`), // = "foobar" + GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, + Scopes: []string{"fosite", "offline", "openid"}, + Audience: []string{tokenURL}, + Public: true, + }, }, Users: map[string]storage.MemoryUserRelation{ "peter": { @@ -190,16 +198,18 @@ var hmacStrategy = &oauth2.HMACSHAStrategy{ }, } -var defaultRSAKey = gen.MustRSAKey() -var jwtStrategy = &oauth2.DefaultJWTStrategy{ - Signer: &jwt.DefaultSigner{ - GetPrivateKey: func(ctx context.Context) (interface{}, error) { - return defaultRSAKey, nil +var ( + defaultRSAKey = gen.MustRSAKey() + jwtStrategy = &oauth2.DefaultJWTStrategy{ + Signer: &jwt.DefaultSigner{ + GetPrivateKey: func(ctx context.Context) (interface{}, error) { + return defaultRSAKey, nil + }, }, - }, - Config: &fosite.Config{}, - HMACSHAStrategy: hmacStrategy, -} + Config: &fosite.Config{}, + HMACSHAStrategy: hmacStrategy, + } +) func mockServer(t *testing.T, f fosite.OAuth2Provider, session fosite.Session) *httptest.Server { router := mux.NewRouter()