Skip to content

Commit

Permalink
fix: do not validate request when creating response
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Mar 29, 2024
1 parent 3342f41 commit 322726b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 63 deletions.
4 changes: 4 additions & 0 deletions handler/oauth2/flow_authorize_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func (c AuthorizeCodeHandler) Code(ctx context.Context, requester fosite.AccessR
}

func (c AuthorizeCodeHandler) ValidateCode(ctx context.Context, requester fosite.Requester, code string) error {
return nil
}

func (c AuthorizeCodeHandler) ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error {
return c.AuthorizeCodeStrategy.ValidateAuthorizeCode(ctx, requester, code)
}

Expand Down
22 changes: 0 additions & 22 deletions handler/oauth2/flow_authorize_code_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,6 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) {
},
expectErr: fosite.ErrServerError,
},
{
description: "should fail because authorization code is expired",
areq: &fosite.AccessRequest{
GrantTypes: fosite.Arguments{"authorization_code"},
Request: fosite.Request{
Form: url.Values{"code": []string{"foo.bar"}},
Client: &fosite.DefaultClient{
GrantTypes: fosite.Arguments{"authorization_code"},
},
Session: &fosite.DefaultSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.AuthorizeCode: time.Now().Add(-time.Hour).UTC(),
},
},
RequestedAt: time.Now().Add(-2 * time.Hour).UTC(),
},
},
setup: func(t *testing.T, areq *fosite.AccessRequest, config *fosite.Config) {
require.NoError(t, store.CreateAuthorizeCodeSession(context.Background(), "bar", areq))
},
expectErr: fosite.ErrInvalidRequest,
},
{
description: "should pass with offline scope and refresh token",
areq: &fosite.AccessRequest{
Expand Down
15 changes: 11 additions & 4 deletions handler/oauth2/flow_generic_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ type CodeHandler interface {
// Code fetches the code and code signature.
Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error)

// ValidateCode validates the code.
// ValidateCode validates the code. Can be used for checks that need to run before we fetch the session from the database.
ValidateCode(ctx context.Context, requester fosite.Requester, code string) error

// ValidateCodeSession validates the code session.
ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error
}

// SessionHandler handles session-related operations.
Expand Down Expand Up @@ -83,8 +86,8 @@ func (c *GenericCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx cont
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

if err = c.ValidateCode(ctx, requester, code); err != nil {
return errorsx.WithStack(fosite.ErrInvalidRequest.WithWrap(err).WithDebug(err.Error()))
if err = c.ValidateCodeSession(ctx, ar, code); err != nil {
return errorsx.WithStack(err)
}

for _, scope := range ar.GetRequestedScopes() {
Expand Down Expand Up @@ -166,6 +169,10 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context
return err
}

if err = c.ValidateCode(ctx, requester, code); err != nil {
return errorsx.WithStack(err)
}

var ar fosite.Requester
if ar, err = c.Session(ctx, requester, signature); err != nil {
if ar != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) {
Expand All @@ -175,7 +182,7 @@ func (c *GenericCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context
return err
}

if err = c.ValidateCode(ctx, ar, code); err != nil {
if err = c.ValidateCodeSession(ctx, ar, code); err != nil {
return errorsx.WithStack(err)
}

Expand Down
21 changes: 12 additions & 9 deletions handler/rfc8628/token_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ type DeviceCodeHandler struct {
func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequester) (code string, signature string, err error) {
code = requester.GetRequestForm().Get("device_code")

shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code)
// TODO(nsklikas) : should we error out or just silently log it?
if err != nil {
return "", "", err
}
if shouldRateLimit {
return "", "", errorsx.WithStack(fosite.ErrPollingRateLimited)
}

signature, err = c.DeviceCodeStrategy.DeviceCodeSignature(ctx, code)
if err != nil {
return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
Expand All @@ -40,6 +31,18 @@ func (c DeviceCodeHandler) Code(ctx context.Context, requester fosite.AccessRequ
}

func (c DeviceCodeHandler) ValidateCode(ctx context.Context, requester fosite.Requester, code string) error {
shouldRateLimit, err := c.DeviceRateLimitStrategy.ShouldRateLimit(ctx, code)
// TODO(nsklikas) : should we error out or just silently log it?
if err != nil {
return err
}
if shouldRateLimit {
return errorsx.WithStack(fosite.ErrPollingRateLimited)
}
return nil
}

func (c DeviceCodeHandler) ValidateCodeSession(ctx context.Context, requester fosite.Requester, code string) error {
return c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code)
}

Expand Down
28 changes: 0 additions & 28 deletions handler/rfc8628/token_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,33 +376,6 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
},
expectErr: fosite.ErrServerError,
},
{
description: "should fail because device code is expired",
areq: &fosite.AccessRequest{
GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)},
Request: fosite.Request{
Form: url.Values{},
Client: &fosite.DefaultClient{
GrantTypes: fosite.Arguments{string(fosite.GrantTypeDeviceCode)},
},
Session: &DefaultDeviceFlowSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.DeviceCode: time.Now().Add(-time.Hour).UTC(),
},
BrowserFlowCompleted: true,
},
RequestedAt: time.Now().Add(-2 * time.Hour).UTC(),
},
},
setup: func(t *testing.T, areq *fosite.AccessRequest, config *fosite.Config) {
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.Background(), signature, areq))
},
expectErr: fosite.ErrInvalidRequest,
},
{
description: "should pass with offline scope and refresh token",
areq: &fosite.AccessRequest{
Expand Down Expand Up @@ -767,7 +740,6 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) {
mockCoreStore = internal.NewMockCoreStorage(ctrl)
mockDeviceCodeStore = internal.NewMockDeviceCodeStorage(ctrl)
mockDeviceRateLimitStrategy = internal.NewMockDeviceRateLimitStrategy(ctrl)
mockDeviceRateLimitStrategy.EXPECT().ShouldRateLimit(gomock.Any(), gomock.Any()).Return(false, nil).Times(1)
testCase.setup()

handler := oauth2.GenericCodeTokenEndpointHandler{
Expand Down

0 comments on commit 322726b

Please sign in to comment.