From 6322f685a80a06a7440ab52ba874334de6102301 Mon Sep 17 00:00:00 2001 From: Nikos Date: Fri, 9 Feb 2024 14:31:09 +0200 Subject: [PATCH] fix: review comments --- compose/compose_rfc8628.go | 4 +- compose/compose_strategy.go | 1 + config.go | 2 + config_default.go | 15 +- device_request.go | 3 +- device_request_handler.go | 33 +--- device_request_handler_test.go | 157 +++++++++--------- device_request_test.go | 2 +- device_response.go | 59 ++++--- device_response_test.go | 2 +- device_response_writer.go | 3 +- device_write.go | 25 +-- device_write_test.go | 2 +- handler.go | 1 + handler/rfc8628/auth_handler.go | 10 +- handler/rfc8628/auth_handler_test.go | 62 +++---- handler/rfc8628/storage.go | 5 +- handler/rfc8628/strategy.go | 6 +- handler/rfc8628/strategy_hmacsha.go | 22 ++- handler/rfc8628/strategy_hmacsha_test.go | 2 +- .../authorize_device_grant_request_test.go | 4 +- oauth2.go | 13 ++ storage/memory.go | 7 + token/hmac/hmacsha.go | 1 + 24 files changed, 229 insertions(+), 212 deletions(-) diff --git a/compose/compose_rfc8628.go b/compose/compose_rfc8628.go index fb377606b..5217aeb7f 100644 --- a/compose/compose_rfc8628.go +++ b/compose/compose_rfc8628.go @@ -1,6 +1,8 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 +// Package compose provides various objects which can be used to +// instantiate OAuth2Providers with different functionality. package compose import ( diff --git a/compose/compose_strategy.go b/compose/compose_strategy.go index 748d8eab4..1367485a6 100644 --- a/compose/compose_strategy.go +++ b/compose/compose_strategy.go @@ -55,6 +55,7 @@ func NewOpenIDConnectStrategy(keyGetter func(context.Context) (interface{}, erro } } +// Create a new device strategy func NewDeviceStrategy(config fosite.Configurator) *rfc8628.DefaultDeviceStrategy { return &rfc8628.DefaultDeviceStrategy{ Enigma: &hmac.HMACStrategy{Config: config}, diff --git a/config.go b/config.go index d95fa88d9..86802f0e5 100644 --- a/config.go +++ b/config.go @@ -46,6 +46,7 @@ type IDTokenLifespanProvider interface { GetIDTokenLifespan(ctx context.Context) time.Duration } +// DeviceAndUserCodeLifespanProvider returns the provider for configuring the device and user code lifespan type DeviceAndUserCodeLifespanProvider interface { GetDeviceAndUserCodeLifespan(ctx context.Context) time.Duration } @@ -80,6 +81,7 @@ type DisableRefreshTokenValidationProvider interface { GetDisableRefreshTokenValidation(ctx context.Context) bool } +// DeviceProvider returns the provider for configuring the device flow type DeviceProvider interface { GetDeviceVerificationURL(ctx context.Context) string GetDeviceAuthTokenPollingInterval(ctx context.Context) time.Duration diff --git a/config_default.go b/config_default.go index cc6acece3..4d654dbc8 100644 --- a/config_default.go +++ b/config_default.go @@ -18,8 +18,10 @@ import ( ) const ( - defaultPARPrefix = "urn:ietf:params:oauth:request_uri:" - defaultPARContextLifetime = 5 * time.Minute + defaultPARPrefix = "urn:ietf:params:oauth:request_uri:" + defaultPARContextLifetime = 5 * time.Minute + defaultDeviceAndUserCodeLifespan = 10 * time.Minute + defaultAuthTokenPollingInterval = 5 * time.Second ) var ( @@ -257,6 +259,7 @@ func (c *Config) GetTokenIntrospectionHandlers(ctx context.Context) TokenIntrosp return c.TokenIntrospectionHandlers } +// GetDeviceEndpointHandlers return the Device Endpoint Handlers func (c *Config) GetDeviceEndpointHandlers(ctx context.Context) DeviceEndpointHandlers { return c.DeviceEndpointHandlers } @@ -412,9 +415,11 @@ func (c *Config) GetRefreshTokenLifespan(_ context.Context) time.Duration { return c.RefreshTokenLifespan } +// GetDeviceAndUserCodeLifespan returns how long the device and user codes should be valid. +// Defaults to 10 minutes func (c *Config) GetDeviceAndUserCodeLifespan(_ context.Context) time.Duration { if c.DeviceAndUserCodeLifespan == 0 { - return time.Minute * 10 + return defaultDeviceAndUserCodeLifespan } return c.DeviceAndUserCodeLifespan } @@ -523,13 +528,15 @@ func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool { return c.IsPushedAuthorizeEnforced } +// GetDeviceVerificationURL returns the device verification URL func (c *Config) GetDeviceVerificationURL(ctx context.Context) string { return c.DeviceVerificationURL } +// GetDeviceAuthTokenPollingInterval returns configured device token endpoint polling interval func (c *Config) GetDeviceAuthTokenPollingInterval(ctx context.Context) time.Duration { if c.DeviceAuthTokenPollingInterval == 0 { - return time.Second * 5 + return defaultAuthTokenPollingInterval } return c.DeviceAuthTokenPollingInterval } diff --git a/device_request.go b/device_request.go index 3c2df0636..0b243b015 100644 --- a/device_request.go +++ b/device_request.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite @@ -8,6 +8,7 @@ type DeviceRequest struct { Request } +// NewDeviceRequest returns a new device request func NewDeviceRequest() *DeviceRequest { return &DeviceRequest{ Request: *NewRequest(), diff --git a/device_request_handler.go b/device_request_handler.go index 20d3b1e26..cec617df5 100644 --- a/device_request_handler.go +++ b/device_request_handler.go @@ -1,27 +1,6 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 -/* - * Copyright © 2015-2021 Aeneas Rekkas - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @author Aeneas Rekkas - * @copyright 2015-2021 Aeneas Rekkas - * @license Apache-2.0 - * - */ - package fosite import ( @@ -29,12 +8,14 @@ import ( "net/http" "strings" - "github.com/ory/fosite/i18n" "github.com/ory/x/errorsx" "github.com/ory/x/otelx" "go.opentelemetry.io/otel/trace" + + "github.com/ory/fosite/i18n" ) +// NewDeviceRequest parses an http Request returns a Device request func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ DeviceRequester, err error) { ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("github.com/ory/fosite").Start(ctx, "Fosite.NewAccessRequest") defer otelx.End(span, &err) @@ -44,9 +25,11 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic if r.Method != "POST" { return request, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method)) - } else if err := r.ParseForm(); err != nil { + } + if err := r.ParseForm(); err != nil { return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error())) - } else if len(r.PostForm) == 0 { + } + if len(r.PostForm) == 0 { return request, errorsx.WithStack(ErrInvalidRequest.WithHint("The POST body can not be empty.")) } request.Form = r.PostForm diff --git a/device_request_handler_test.go b/device_request_handler_test.go index d6dceeffa..0b5b38e6d 100644 --- a/device_request_handler_test.go +++ b/device_request_handler_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite_test @@ -33,94 +33,87 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { expectedError error mock func() expect DeviceRequester - }{ - /* invalid Method */ - { - expectedError: ErrInvalidRequest, - method: "GET", - mock: func() {}, + description string + }{{ + description: "invalid method", + expectedError: ErrInvalidRequest, + method: "GET", + mock: func() {}, + }, { + description: "empty request", + expectedError: ErrInvalidRequest, + method: "POST", + mock: func() {}, + }, { + description: "invalid client", + form: url.Values{ + "client_id": {"client_id"}, + "scope": {"foo bar"}, }, - /* empty request */ - { - expectedError: ErrInvalidRequest, - method: "POST", - mock: func() {}, + expectedError: ErrInvalidClient, + method: "POST", + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(nil, errors.New("")) }, - /* invalid client */ - { - form: url.Values{ - "client_id": {"client_id"}, - "scope": {"foo bar"}, - }, - expectedError: ErrInvalidClient, - method: "POST", - mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(nil, errors.New("")) - }, + }, { + description: "fails because scope not allowed", + form: url.Values{ + "client_id": {"client_id"}, + "scope": {"17 42 foo"}, }, - /* fails because scope not allowed */ - { - form: url.Values{ - "client_id": {"client_id"}, - "scope": {"17 42 foo"}, - }, - method: "POST", - mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} - }, - expectedError: ErrInvalidScope, + method: "POST", + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) + client.Public = true + client.Scopes = []string{"17", "42"} + client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} }, - /* fails because scope not allowed */ - { - form: url.Values{ - "client_id": {"client_id"}, - "scope": {"17 42"}, - "audience": {"aud"}, - }, - method: "POST", - mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.Audience = []string{"aud2"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} - }, - expectedError: ErrInvalidRequest, + expectedError: ErrInvalidScope, + }, { + description: "fails because audience not allowed", + form: url.Values{ + "client_id": {"client_id"}, + "scope": {"17 42"}, + "audience": {"aud"}, }, - /* should fail because doesn't have proper grant */ - { - form: url.Values{ - "client_id": {"client_id"}, - "scope": {"17 42"}, - }, - method: "POST", - mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.GrantTypes = []string{"authorization_code"} - }, - expectedError: ErrInvalidGrant, + method: "POST", + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) + client.Public = true + client.Scopes = []string{"17", "42"} + client.Audience = []string{"aud2"} + client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} }, - /* success case */ - { - form: url.Values{ - "client_id": {"client_id"}, - "scope": {"17 42"}, - }, - method: "POST", - mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} - }, + expectedError: ErrInvalidRequest, + }, { + description: "fails because it doesn't have the proper grant", + form: url.Values{ + "client_id": {"client_id"}, + "scope": {"17 42"}, }, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + method: "POST", + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) + client.Public = true + client.Scopes = []string{"17", "42"} + client.GrantTypes = []string{"authorization_code"} + }, + expectedError: ErrInvalidGrant, + }, { + description: "success", + form: url.Values{ + "client_id": {"client_id"}, + "scope": {"17 42"}, + }, + method: "POST", + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) + client.Public = true + client.Scopes = []string{"17", "42"} + client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + }, + }} { + t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) { c.mock() r := &http.Request{ Header: c.header, diff --git a/device_request_test.go b/device_request_test.go index 571e83ba8..7e67c0529 100644 --- a/device_request_test.go +++ b/device_request_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite diff --git a/device_response.go b/device_response.go index b9b9d655e..bad3e2064 100644 --- a/device_response.go +++ b/device_response.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite @@ -9,7 +9,8 @@ import ( "net/http" ) -type deviceResponse struct { +// DeviceResponse represents the device authorization response +type DeviceResponse struct { Header http.Header DeviceCode string `json:"device_code"` UserCode string `json:"user_code"` @@ -19,79 +20,87 @@ type deviceResponse struct { Interval int `json:"interval,omitempty"` } -type DeviceResponse struct { - deviceResponse -} - +// NewDeviceResponse returns a new DeviceResponse func NewDeviceResponse() *DeviceResponse { return &DeviceResponse{} } +// GetDeviceCode returns the response's device_code func (d *DeviceResponse) GetDeviceCode() string { - return d.deviceResponse.DeviceCode + return d.DeviceCode } -// SetDeviceCode returns the response's user code +// SetDeviceCode sets the response's device_code func (d *DeviceResponse) SetDeviceCode(code string) { - d.deviceResponse.DeviceCode = code + d.DeviceCode = code } +// GetUserCode returns the response's user_code func (d *DeviceResponse) GetUserCode() string { - return d.deviceResponse.UserCode + return d.UserCode } +// SetUserCode sets the response's user_code func (d *DeviceResponse) SetUserCode(code string) { - d.deviceResponse.UserCode = code + d.UserCode = code } // GetVerificationURI returns the response's verification uri func (d *DeviceResponse) GetVerificationURI() string { - return d.deviceResponse.VerificationURI + return d.VerificationURI } +// SetVerificationURI sets the response's verification uri func (d *DeviceResponse) SetVerificationURI(uri string) { - d.deviceResponse.VerificationURI = uri + d.VerificationURI = uri } // GetVerificationURIComplete returns the response's complete verification uri if set func (d *DeviceResponse) GetVerificationURIComplete() string { - return d.deviceResponse.VerificationURIComplete + return d.VerificationURIComplete } +// SetVerificationURIComplete sets the response's complete verification uri func (d *DeviceResponse) SetVerificationURIComplete(uri string) { - d.deviceResponse.VerificationURIComplete = uri + d.VerificationURIComplete = uri } // GetExpiresIn returns the response's device code and user code lifetime in seconds if set func (d *DeviceResponse) GetExpiresIn() int64 { - return d.deviceResponse.ExpiresIn + return d.ExpiresIn } +// SetExpiresIn sets the response's device code and user code lifetime in seconds func (d *DeviceResponse) SetExpiresIn(seconds int64) { - d.deviceResponse.ExpiresIn = seconds + d.ExpiresIn = seconds } // GetInterval returns the response's polling interval if set func (d *DeviceResponse) GetInterval() int { - return d.deviceResponse.Interval + return d.Interval } +// SetInterval sets the response's polling interval func (d *DeviceResponse) SetInterval(seconds int) { - d.deviceResponse.Interval = seconds + d.Interval = seconds } -func (a *DeviceResponse) GetHeader() http.Header { - return a.deviceResponse.Header +// GetHeader returns the response's headers +func (d *DeviceResponse) GetHeader() http.Header { + return d.Header } -func (a *DeviceResponse) AddHeader(key, value string) { - a.deviceResponse.Header.Add(key, value) +// AddHeader adds a header to the response +func (d *DeviceResponse) AddHeader(key, value string) { + d.Header.Add(key, value) } +// FromJson populates a response's fields from a json func (d *DeviceResponse) FromJson(r io.Reader) error { - return json.NewDecoder(r).Decode(&d.deviceResponse) + return json.NewDecoder(r).Decode(&d) } +// ToJson writes a response as a json func (d *DeviceResponse) ToJson(rw io.Writer) error { - return json.NewEncoder(rw).Encode(&d.deviceResponse) + return json.NewEncoder(rw).Encode(&d) } diff --git a/device_response_test.go b/device_response_test.go index 366899a1b..a4e95e168 100644 --- a/device_response_test.go +++ b/device_response_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite diff --git a/device_response_writer.go b/device_response_writer.go index 82c5d7c93..2cc17d096 100644 --- a/device_response_writer.go +++ b/device_response_writer.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite @@ -7,6 +7,7 @@ import ( "context" ) +// NewDeviceResponse returns a new DeviceResponder func (f *Fosite) NewDeviceResponse(ctx context.Context, r DeviceRequester, session Session) (DeviceResponder, error) { var resp = &DeviceResponse{} diff --git a/device_write.go b/device_write.go index 0e8fa77b5..9240140f9 100644 --- a/device_write.go +++ b/device_write.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite @@ -8,8 +8,7 @@ import ( "net/http" ) -// TODO: Do documentation - +// WriteDeviceResponse writes the device response func (f *Fosite) WriteDeviceResponse(ctx context.Context, rw http.ResponseWriter, requester DeviceRequester, responder DeviceResponder) { // Set custom headers, e.g. "X-MySuperCoolCustomHeader" or "X-DONT-CACHE-ME"... wh := rw.Header() @@ -23,15 +22,17 @@ func (f *Fosite) WriteDeviceResponse(ctx context.Context, rw http.ResponseWriter rw.Header().Set("Pragma", "no-cache") deviceResponse := &DeviceResponse{ - deviceResponse{ - DeviceCode: responder.GetDeviceCode(), - UserCode: responder.GetUserCode(), - VerificationURI: responder.GetVerificationURI(), - VerificationURIComplete: responder.GetVerificationURIComplete(), - ExpiresIn: responder.GetExpiresIn(), - Interval: responder.GetInterval(), - }, + DeviceCode: responder.GetDeviceCode(), + UserCode: responder.GetUserCode(), + VerificationURI: responder.GetVerificationURI(), + VerificationURIComplete: responder.GetVerificationURIComplete(), + ExpiresIn: responder.GetExpiresIn(), + Interval: responder.GetInterval(), } - _ = deviceResponse.ToJson(rw) + err := deviceResponse.ToJson(rw) + if err != nil { + http.Error(rw, ErrServerError.WithWrap(err).WithDebug(err.Error()).Error(), http.StatusInternalServerError) + return + } } diff --git a/device_write_test.go b/device_write_test.go index 0ed418cb7..7ec9b5d43 100644 --- a/device_write_test.go +++ b/device_write_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package fosite_test diff --git a/handler.go b/handler.go index d95b0f367..6f11626f0 100644 --- a/handler.go +++ b/handler.go @@ -67,6 +67,7 @@ type PushedAuthorizeEndpointHandler interface { HandlePushedAuthorizeEndpointRequest(ctx context.Context, requester AuthorizeRequester, responder PushedAuthorizeResponder) error } +// DeviceEndpointHandler is the interface that handles https://tools.ietf.org/html/rfc8628 type DeviceEndpointHandler interface { // HandleDeviceEndpointRequest handles a device authorize endpoint request. To extend the handler's capabilities, the http request // is passed along, if further information retrieval is required. If the handler feels that he is not responsible for diff --git a/handler/rfc8628/auth_handler.go b/handler/rfc8628/auth_handler.go index 9072cbe46..6d97da3aa 100644 --- a/handler/rfc8628/auth_handler.go +++ b/handler/rfc8628/auth_handler.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8628 @@ -7,10 +7,13 @@ import ( "context" "time" - "github.com/ory/fosite" "github.com/ory/x/errorsx" + + "github.com/ory/fosite" ) +// DeviceAuthHandler is a response handler for the Device Authorisation Grant as +// defined in https://tools.ietf.org/html/rfc8628#section-3.1 type DeviceAuthHandler struct { Storage RFC8628CoreStorage Strategy RFC8628CodeStrategy @@ -20,8 +23,7 @@ type DeviceAuthHandler struct { } } -// DeviceAuthorizationHandler is a response handler for the Device Authorisation Grant as -// defined in https://tools.ietf.org/html/rfc8628#section-3.1 +// HandleDeviceEndpointRequest implements https://tools.ietf.org/html/rfc8628#section-3.1 func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar fosite.DeviceRequester, resp fosite.DeviceResponder) error { deviceCode, deviceCodeSignature, err := d.Strategy.GenerateDeviceCode(ctx) if err != nil { diff --git a/handler/rfc8628/auth_handler_test.go b/handler/rfc8628/auth_handler_test.go index c29112bf6..220a03475 100644 --- a/handler/rfc8628/auth_handler_test.go +++ b/handler/rfc8628/auth_handler_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8628_test @@ -9,18 +9,19 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/ory/fosite" - . "github.com/ory/fosite/handler/rfc8628" "github.com/ory/fosite/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/rfc8628" ) func Test_HandleDeviceEndpointRequest(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() store := storage.NewMemoryStore() - handler := DeviceAuthHandler{ + handler := rfc8628.DeviceAuthHandler{ Storage: store, Strategy: &hmacshaStrategy, Config: &fosite.Config{ @@ -35,45 +36,22 @@ func Test_HandleDeviceEndpointRequest(t *testing.T) { }, } - for _, c := range []struct { - handler DeviceAuthHandler - req *fosite.DeviceRequest - description string - expectErr error - expect func(t *testing.T, req *fosite.DeviceRequest, resp *fosite.DeviceResponse) - }{ - { - handler: handler, - req: &fosite.DeviceRequest{ - Request: fosite.Request{ - Client: &fosite.DefaultClient{ - Audience: []string{"https://www.ory.sh/api"}, - }, - Session: &fosite.DefaultSession{}, - }, - }, - expect: func(t *testing.T, req *fosite.DeviceRequest, resp *fosite.DeviceResponse) { - assert.NotEmpty(t, resp.GetDeviceCode()) - assert.NotEmpty(t, resp.GetUserCode()) - assert.Equal(t, len(resp.GetUserCode()), 8) - assert.Contains(t, resp.GetDeviceCode(), "ory_dc_") - assert.Contains(t, resp.GetDeviceCode(), ".") - assert.Equal(t, resp.GetVerificationURI(), "www.test.com") + req := &fosite.DeviceRequest{ + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + Audience: []string{"https://www.ory.sh/api"}, }, + Session: &fosite.DefaultSession{}, }, - } { - t.Run("case="+c.description, func(t *testing.T) { - resp := fosite.NewDeviceResponse() - err := c.handler.HandleDeviceEndpointRequest(context.Background(), c.req, resp) - if c.expectErr != nil { - require.EqualError(t, err, c.expectErr.Error()) - } else { - require.NoError(t, err) - } - - if c.expect != nil { - c.expect(t, c.req, resp) - } - }) } + resp := fosite.NewDeviceResponse() + err := handler.HandleDeviceEndpointRequest(context.Background(), req, resp) + + require.NoError(t, err) + assert.NotEmpty(t, resp.GetDeviceCode()) + assert.NotEmpty(t, resp.GetUserCode()) + assert.Equal(t, len(resp.GetUserCode()), 8) + assert.Contains(t, resp.GetDeviceCode(), "ory_dc_") + assert.Contains(t, resp.GetDeviceCode(), ".") + assert.Equal(t, resp.GetVerificationURI(), "www.test.com") } diff --git a/handler/rfc8628/storage.go b/handler/rfc8628/storage.go index f356c6aaa..17571ab18 100644 --- a/handler/rfc8628/storage.go +++ b/handler/rfc8628/storage.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8628 @@ -10,6 +10,7 @@ import ( "github.com/ory/fosite/handler/oauth2" ) +// RFC8628CoreStorage is the storage needed for the DeviceAuthHandler type RFC8628CoreStorage interface { DeviceCodeStorage UserCodeStorage @@ -17,6 +18,7 @@ type RFC8628CoreStorage interface { oauth2.RefreshTokenStorage } +// DeviceCodeStorage handles the device_code storage type DeviceCodeStorage interface { // CreateDeviceCodeSession stores the device request for a given device code. CreateDeviceCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error) @@ -34,6 +36,7 @@ type DeviceCodeStorage interface { InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error) } +// UserCodeStorage handles the user_code storage type UserCodeStorage interface { // CreateUserCodeSession stores the device request for a given user code. CreateUserCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error) diff --git a/handler/rfc8628/strategy.go b/handler/rfc8628/strategy.go index 3b0df7a71..33900a20c 100644 --- a/handler/rfc8628/strategy.go +++ b/handler/rfc8628/strategy.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8628 @@ -9,22 +9,26 @@ import ( "github.com/ory/fosite" ) +// RFC8628CodeStrategy is the code strategy needed for the DeviceAuthHandler type RFC8628CodeStrategy interface { DeviceRateLimitStrategy DeviceCodeStrategy UserCodeStrategy } +// DeviceRateLimitStrategy handles the rate limiting strategy type DeviceRateLimitStrategy interface { ShouldRateLimit(ctx context.Context, code string) bool } +// DeviceCodeStrategy handles the device_code strategy type DeviceCodeStrategy interface { DeviceCodeSignature(ctx context.Context, code string) (signature string, err error) GenerateDeviceCode(ctx context.Context) (code string, signature string, err error) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) (err error) } +// UserCodeStrategy handles the user_code strategy type UserCodeStrategy interface { UserCodeSignature(ctx context.Context, code string) (signature string, err error) GenerateUserCode(ctx context.Context) (code string, signature string, err error) diff --git a/handler/rfc8628/strategy_hmacsha.go b/handler/rfc8628/strategy_hmacsha.go index 3e600a164..6f8068c1e 100644 --- a/handler/rfc8628/strategy_hmacsha.go +++ b/handler/rfc8628/strategy_hmacsha.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8628 @@ -14,6 +14,7 @@ import ( enigma "github.com/ory/fosite/token/hmac" ) +// DefaultDeviceStrategy implements the default device strategy type DefaultDeviceStrategy struct { Enigma *enigma.HMACStrategy RateLimiterCache *cache.Cache @@ -25,7 +26,8 @@ type DefaultDeviceStrategy struct { var _ RFC8628CodeStrategy = (*DefaultDeviceStrategy)(nil) -func (h *DefaultDeviceStrategy) GenerateUserCode(ctx context.Context) (token string, signature string, err error) { +// GenerateUserCode generates a user_code +func (h *DefaultDeviceStrategy) GenerateUserCode(ctx context.Context) (string, string, error) { seq, err := randx.RuneSequence(8, []rune(randx.AlphaUpper)) if err != nil { return "", "", err @@ -38,16 +40,19 @@ func (h *DefaultDeviceStrategy) GenerateUserCode(ctx context.Context) (token str return userCode, signUserCode, nil } -func (h *DefaultDeviceStrategy) UserCodeSignature(ctx context.Context, token string) (signature string, err error) { +// UserCodeSignature generates a user_code signature +func (h *DefaultDeviceStrategy) UserCodeSignature(ctx context.Context, token string) (string, error) { return h.Enigma.GenerateHMACForString(ctx, token) } -func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.Requester, code string) (err error) { +// ValidateUserCode validates a user_code +func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.Requester, code string) error { // TODO return nil } -func (h *DefaultDeviceStrategy) GenerateDeviceCode(ctx context.Context) (token string, signature string, err error) { +// GenerateDeviceCode generates a device_code +func (h *DefaultDeviceStrategy) GenerateDeviceCode(ctx context.Context) (string, string, error) { token, sig, err := h.Enigma.Generate(ctx) if err != nil { return "", "", err @@ -56,15 +61,18 @@ func (h *DefaultDeviceStrategy) GenerateDeviceCode(ctx context.Context) (token s return "ory_dc_" + token, sig, nil } -func (h *DefaultDeviceStrategy) DeviceCodeSignature(ctx context.Context, token string) (signature string, err error) { +// DeviceCodeSignature generates a device_code signature +func (h *DefaultDeviceStrategy) DeviceCodeSignature(ctx context.Context, token string) (string, error) { return h.Enigma.Signature(token), nil } -func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) (err error) { +// ValidateDeviceCode validates a device_code +func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) error { // TODO return nil } +// ShouldRateLimit is used to decide whether a request should be rate-limites func (t *DefaultDeviceStrategy) ShouldRateLimit(context context.Context, code string) bool { key := code + "_limiter" diff --git a/handler/rfc8628/strategy_hmacsha_test.go b/handler/rfc8628/strategy_hmacsha_test.go index 894fc2ae3..70b687c0c 100644 --- a/handler/rfc8628/strategy_hmacsha_test.go +++ b/handler/rfc8628/strategy_hmacsha_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package rfc8628_test diff --git a/integration/authorize_device_grant_request_test.go b/integration/authorize_device_grant_request_test.go index 2f548f147..2deb30678 100644 --- a/integration/authorize_device_grant_request_test.go +++ b/integration/authorize_device_grant_request_test.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Ory Corp +// Copyright © 2024 Ory Corp // SPDX-License-Identifier: Apache-2.0 package integration_test @@ -100,7 +100,7 @@ func runDeviceFlowTest(t *testing.T, strategy interface{}) { err: false, }, } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) { // Restore client fositeStore.Clients["device-client"] = &fosite.DefaultClient{ ID: "device-client", diff --git a/oauth2.go b/oauth2.go index b1d737441..cad5a9550 100644 --- a/oauth2.go +++ b/oauth2.go @@ -402,23 +402,36 @@ type G11NContext interface { GetLang() language.Tag } +// DeviceResponder is the device authorization endpoint's response type DeviceResponder interface { + // GetDeviceCode returns the device_code GetDeviceCode() string + // SetDeviceCode sets the device_code SetDeviceCode(code string) + // GetUserCode returns the user_code GetUserCode() string + // SetUserCode sets the user_code SetUserCode(code string) + // GetVerificationURI returns the verification_uri GetVerificationURI() string + // SetVerificationURI sets the verification_uri SetVerificationURI(uri string) + // GetVerificationURIComplete returns the verification_uri_complete GetVerificationURIComplete() string + // SetVerificationURIComplete sets the verification_uri_complete SetVerificationURIComplete(uri string) + // GetExpiresIn returns the expires_in GetExpiresIn() int64 + // SetExpiresIn sets the expires_in SetExpiresIn(seconds int64) + // GetInterval returns the interval GetInterval() int + // SetInterval sets the interval SetInterval(seconds int) // GetHeader returns the response's header diff --git a/storage/memory.go b/storage/memory.go index b4f88aa8a..becddd515 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -514,6 +514,7 @@ func (s *MemoryStore) DeletePARSession(ctx context.Context, requestURI string) ( return nil } +// CreateDeviceCodeSession stores the device code session func (s *MemoryStore) CreateDeviceCodeSession(_ context.Context, signature string, req fosite.Requester) error { // We first lock accessTokenRequestIDsMutex and then accessTokensMutex because this is the same order // locking happens in RevokeAccessToken and using the same order prevents deadlocks. @@ -527,6 +528,7 @@ func (s *MemoryStore) CreateDeviceCodeSession(_ context.Context, signature strin return nil } +// UpdateDeviceCodeSession updates the device code session func (s *MemoryStore) UpdateDeviceCodeSession(_ context.Context, signature string, req fosite.Requester) error { s.deviceCodesRequestIDsMutex.Lock() defer s.deviceCodesRequestIDsMutex.Unlock() @@ -541,6 +543,7 @@ func (s *MemoryStore) UpdateDeviceCodeSession(_ context.Context, signature strin return nil } +// GetDeviceCodeSession gets the device code session func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { s.deviceCodesMutex.RLock() defer s.deviceCodesMutex.RUnlock() @@ -552,6 +555,7 @@ func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, return rel, nil } +// InvalidateDeviceCodeSession invalidates the device code session func (s *MemoryStore) InvalidateDeviceCodeSession(_ context.Context, code string) error { s.deviceCodesRequestIDsMutex.Lock() defer s.deviceCodesRequestIDsMutex.Unlock() @@ -562,6 +566,7 @@ func (s *MemoryStore) InvalidateDeviceCodeSession(_ context.Context, code string return nil } +// CreateUserCodeSession stores the user code session func (s *MemoryStore) CreateUserCodeSession(_ context.Context, signature string, req fosite.Requester) error { s.userCodesRequestIDsMutex.Lock() defer s.userCodesRequestIDsMutex.Unlock() @@ -573,6 +578,7 @@ func (s *MemoryStore) CreateUserCodeSession(_ context.Context, signature string, return nil } +// GetUserCodeSession gets the user code session func (s *MemoryStore) GetUserCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { s.userCodesMutex.RLock() defer s.userCodesMutex.RUnlock() @@ -584,6 +590,7 @@ func (s *MemoryStore) GetUserCodeSession(_ context.Context, signature string, _ return rel, nil } +// GetUserCodeSession invalidates the user code session func (s *MemoryStore) InvalidateUserCodeSession(_ context.Context, code string) error { s.userCodesRequestIDsMutex.Lock() defer s.userCodesRequestIDsMutex.Unlock() diff --git a/token/hmac/hmacsha.go b/token/hmac/hmacsha.go index 86875f6b1..cf7507ae9 100644 --- a/token/hmac/hmacsha.go +++ b/token/hmac/hmacsha.go @@ -170,6 +170,7 @@ func (c *HMACStrategy) Signature(token string) string { return split[1] } +// GenerateHMACForString returns an HMAC for a string func (c *HMACStrategy) GenerateHMACForString(ctx context.Context, text string) (string, error) { var signingKey [32]byte