Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle the user code generation duplication #23

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions generate-mocks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mockgen -package internal -destination internal/oauth2_storage.go github.com/ory
mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/fosite/handler/oauth2 CoreStrategy
mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorage
mockgen -package internal -destination internal/device_code_storage.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStorage
mockgen -package internal -destination internal/rfc8628_core_storage.go github.com/ory/fosite/handler/rfc8628 RFC8628CoreStorage
mockgen -package internal -destination internal/oauth2_auth_jwt_storage.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorage
mockgen -package internal -destination internal/access_token_storage.go github.com/ory/fosite/handler/oauth2 AccessTokenStorage
mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorage
Expand All @@ -17,6 +18,7 @@ mockgen -package internal -destination internal/openid_id_token_storage.go githu
mockgen -package internal -destination internal/access_token_strategy.go github.com/ory/fosite/handler/oauth2 AccessTokenStrategy
mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStrategy
mockgen -package internal -destination internal/authorize_code_strategy.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStrategy
mockgen -package internal -destination internal/rfc8628_code_strategy.go github.com/ory/fosite/handler/rfc8628 RFC8628CodeStrategy
mockgen -package internal -destination internal/device_code_rate_limit_strategy.go github.com/ory/fosite/handler/rfc8628 DeviceRateLimitStrategy
mockgen -package internal -destination internal/id_token_strategy.go github.com/ory/fosite/handler/openid OpenIDConnectTokenStrategy
mockgen -package internal -destination internal/pkce_storage_strategy.go github.com/ory/fosite/handler/pkce PKCERequestStorage
Expand Down
2 changes: 2 additions & 0 deletions generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package fosite
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/oauth2_strategy.go github.com/ory/fosite/handler/oauth2 CoreStrategy
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/authorize_code_storage.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStorage
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/device_code_storage.go github.com/ory/fosite/handler/rfc8628 DeviceCodeStorage
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/rfc8628_core_storage.go github.com/ory/fosite/handler/rfc8628 RFC8628CoreStorage
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/oauth2_auth_jwt_storage.go github.com/ory/fosite/handler/rfc7523 RFC7523KeyStorage
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/access_token_storage.go github.com/ory/fosite/handler/oauth2 AccessTokenStorage
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStorage
Expand All @@ -21,6 +22,7 @@ package fosite
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/refresh_token_strategy.go github.com/ory/fosite/handler/oauth2 RefreshTokenStrategy
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/authorize_code_strategy.go github.com/ory/fosite/handler/oauth2 AuthorizeCodeStrategy
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/device_code_rate_limit_strategy.go github.com/ory/fosite/handler/rfc8628 DeviceRateLimitStrategy
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/rfc8628_code_strategy.go github.com/ory/fosite/handler/rfc8628 RFC8628CodeStrategy
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/id_token_strategy.go github.com/ory/fosite/handler/openid OpenIDConnectTokenStrategy
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/pkce_storage_strategy.go github.com/ory/fosite/handler/pkce PKCERequestStorage
//go:generate go run github.com/golang/mock/mockgen -package internal -destination internal/authorize_handler.go github.com/ory/fosite AuthorizeEndpointHandler
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ require (
github.com/ory/x v0.0.613
github.com/parnurzeal/gorequest v0.2.15
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.14.3
go.opentelemetry.io/otel/trace v1.21.0
golang.org/x/crypto v0.18.0
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,9 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
Expand All @@ -409,8 +410,9 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
Expand Down
65 changes: 48 additions & 17 deletions handler/rfc8628/auth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ package rfc8628

import (
"context"
"fmt"
"time"

"github.com/ory/x/errorsx"

"github.com/ory/fosite"
"github.com/ory/x/errorsx"
)

// MaxAttempts for retrying the generation of user codes.
const MaxAttempts = 3

// DeviceAuthHandler is a response handler for the Device Authorisation Grant as
// defined in https://tools.ietf.org/html/rfc8628#section-3.1
type DeviceAuthHandler struct {
Expand All @@ -25,25 +28,18 @@ type DeviceAuthHandler struct {

// HandleDeviceEndpointRequest implements https://tools.ietf.org/html/rfc8628#section-3.1
func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar fosite.DeviceRequester, resp fosite.DeviceResponder) error {
deviceCode, deviceCodeSignature, err := d.Strategy.GenerateDeviceCode(ctx)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
var err error

userCode, userCodeSignature, err := d.Strategy.GenerateUserCode(ctx)
var deviceCode string
deviceCode, err = d.handleDeviceCode(ctx, dar)
if err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
return err
}

// Store the User Code session (this has no real data other that the user and device code), can be converted into a 'full' session after user auth
dar.GetSession().SetExpiresAt(fosite.DeviceCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)))
if err := d.Storage.CreateDeviceCodeSession(ctx, deviceCodeSignature, dar.Sanitize(nil)); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second))
if err := d.Storage.CreateUserCodeSession(ctx, userCodeSignature, dar.Sanitize(nil)); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
var userCode string
userCode, err = d.handleUserCode(ctx, dar)
if err != nil {
return err
}

// Populate the response fields
Expand All @@ -55,3 +51,38 @@ func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar
resp.SetInterval(int(d.Config.GetDeviceAuthTokenPollingInterval(ctx).Seconds()))
return nil
}

func (d *DeviceAuthHandler) handleDeviceCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) {
code, signature, err := d.Strategy.GenerateDeviceCode(ctx)
if err != nil {
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

dar.GetSession().SetExpiresAt(fosite.DeviceCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)))
if err = d.Storage.CreateDeviceCodeSession(ctx, signature, dar.Sanitize(nil)); err != nil {
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

return code, nil
}

func (d *DeviceAuthHandler) handleUserCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) {
var err error
var userCode, signature string
// Retry when persisting user code fails
// Possible causes include database connection issue, constraints violations, etc.
for i := 0; i < MaxAttempts; i++ {
userCode, signature, err = d.Strategy.GenerateUserCode(ctx)
if err != nil {
return "", err
}

dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second))
if err = d.Storage.CreateUserCodeSession(ctx, signature, dar.Sanitize(nil)); err == nil {
return userCode, nil
}
}

errMsg := fmt.Sprintf("Exceeded user-code generation max attempts %v: %s", MaxAttempts, err.Error())
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg))
}
163 changes: 160 additions & 3 deletions handler/rfc8628/auth_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@ package rfc8628_test

import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/ory/fosite/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/golang/mock/gomock"

"github.com/ory/fosite/internal"

"github.com/ory/fosite"
"github.com/ory/fosite/handler/rfc8628"
"github.com/ory/fosite/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_HandleDeviceEndpointRequest(t *testing.T) {
Expand Down Expand Up @@ -52,3 +57,155 @@ func Test_HandleDeviceEndpointRequest(t *testing.T) {
assert.Contains(t, resp.GetDeviceCode(), ".")
assert.Equal(t, resp.GetVerificationURI(), "www.test.com")
}

func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
var mockRFC8628CoreStorage *internal.MockRFC8628CoreStorage
var mockRFC8628CodeStrategy *internal.MockRFC8628CodeStrategy

ctx := context.Background()
req := &fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{
Audience: []string{"https://www.ory.sh/api"},
},
Session: &fosite.DefaultSession{},
},
}

testCases := []struct {
description string
setup func()
check func(t *testing.T, resp *fosite.DeviceResponse)
expectError error
}{
{
description: "should pass when generating a unique user code at the first attempt",
setup: func() {
mockRFC8628CodeStrategy.
EXPECT().
GenerateDeviceCode(ctx).
Return("deviceCode", "signature", nil)
mockRFC8628CoreStorage.
EXPECT().
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
Return(nil)
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("userCode", "signature", nil).
Times(1)
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "signature", gomock.Any()).
Return(nil).
Times(1)
},
check: func(t *testing.T, resp *fosite.DeviceResponse) {
assert.Equal(t, "userCode", resp.GetUserCode())
},
},
{
description: "should pass when generating a unique user code within allowed attempts",
setup: func() {
mockRFC8628CodeStrategy.
EXPECT().
GenerateDeviceCode(ctx).
Return("deviceCode", "signature", nil)
mockRFC8628CoreStorage.
EXPECT().
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
Return(nil)
gomock.InOrder(
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("duplicatedUserCode", "duplicatedSignature", nil),
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()).
Return(errors.New("unique constraint violation")),
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("uniqueUserCode", "uniqueSignature", nil),
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "uniqueSignature", gomock.Any()).
Return(nil),
)
},
check: func(t *testing.T, resp *fosite.DeviceResponse) {
assert.Equal(t, "uniqueUserCode", resp.GetUserCode())
},
},
{
description: "should fail after maximum retries to generate a unique user code",
setup: func() {
mockRFC8628CodeStrategy.
EXPECT().
GenerateDeviceCode(ctx).
Return("deviceCode", "signature", nil)
mockRFC8628CoreStorage.
EXPECT().
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
Return(nil)
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("duplicatedUserCode", "duplicatedSignature", nil).
Times(rfc8628.MaxAttempts)
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()).
Return(errors.New("unique constraint violation")).
Times(rfc8628.MaxAttempts)
},
check: func(t *testing.T, resp *fosite.DeviceResponse) {
assert.Empty(t, resp.GetUserCode())
},
expectError: fosite.ErrServerError,
},
}

for _, testCase := range testCases {
t.Run(fmt.Sprintf("scenario=%s", testCase.description), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockRFC8628CoreStorage = internal.NewMockRFC8628CoreStorage(ctrl)
mockRFC8628CodeStrategy = internal.NewMockRFC8628CodeStrategy(ctrl)

h := rfc8628.DeviceAuthHandler{
Storage: mockRFC8628CoreStorage,
Strategy: mockRFC8628CodeStrategy,
Config: &fosite.Config{
DeviceAndUserCodeLifespan: time.Minute * 10,
DeviceAuthTokenPollingInterval: time.Second * 5,
DeviceVerificationURL: "www.test.com",
AccessTokenLifespan: time.Hour,
RefreshTokenLifespan: time.Hour,
ScopeStrategy: fosite.HierarchicScopeStrategy,
AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy,
RefreshTokenScopes: []string{"offline"},
},
}

if testCase.setup != nil {
testCase.setup()
}

resp := fosite.NewDeviceResponse()
err := h.HandleDeviceEndpointRequest(ctx, req, resp)

if testCase.expectError != nil {
require.EqualError(t, err, testCase.expectError.Error(), "%+v", err)
} else {
require.NoError(t, err, "%+v", err)
}

if testCase.check != nil {
testCase.check(t, resp)
}
})
}
}
Loading
Loading