Skip to content

Commit

Permalink
fix: review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Feb 12, 2024
1 parent f2c68de commit 6322f68
Show file tree
Hide file tree
Showing 24 changed files with 229 additions and 212 deletions.
4 changes: 3 additions & 1 deletion compose/compose_rfc8628.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright © 2023 Ory Corp
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

// Package compose provides various objects which can be used to
// instantiate OAuth2Providers with different functionality.
package compose

import (
Expand Down
1 change: 1 addition & 0 deletions compose/compose_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func NewOpenIDConnectStrategy(keyGetter func(context.Context) (interface{}, erro
}
}

// Create a new device strategy
func NewDeviceStrategy(config fosite.Configurator) *rfc8628.DefaultDeviceStrategy {
return &rfc8628.DefaultDeviceStrategy{
Enigma: &hmac.HMACStrategy{Config: config},
Expand Down
2 changes: 2 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type IDTokenLifespanProvider interface {
GetIDTokenLifespan(ctx context.Context) time.Duration
}

// DeviceAndUserCodeLifespanProvider returns the provider for configuring the device and user code lifespan
type DeviceAndUserCodeLifespanProvider interface {
GetDeviceAndUserCodeLifespan(ctx context.Context) time.Duration
}
Expand Down Expand Up @@ -80,6 +81,7 @@ type DisableRefreshTokenValidationProvider interface {
GetDisableRefreshTokenValidation(ctx context.Context) bool
}

// DeviceProvider returns the provider for configuring the device flow
type DeviceProvider interface {
GetDeviceVerificationURL(ctx context.Context) string
GetDeviceAuthTokenPollingInterval(ctx context.Context) time.Duration
Expand Down
15 changes: 11 additions & 4 deletions config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
)

const (
defaultPARPrefix = "urn:ietf:params:oauth:request_uri:"
defaultPARContextLifetime = 5 * time.Minute
defaultPARPrefix = "urn:ietf:params:oauth:request_uri:"
defaultPARContextLifetime = 5 * time.Minute
defaultDeviceAndUserCodeLifespan = 10 * time.Minute
defaultAuthTokenPollingInterval = 5 * time.Second
)

var (
Expand Down Expand Up @@ -257,6 +259,7 @@ func (c *Config) GetTokenIntrospectionHandlers(ctx context.Context) TokenIntrosp
return c.TokenIntrospectionHandlers
}

// GetDeviceEndpointHandlers return the Device Endpoint Handlers
func (c *Config) GetDeviceEndpointHandlers(ctx context.Context) DeviceEndpointHandlers {
return c.DeviceEndpointHandlers
}
Expand Down Expand Up @@ -412,9 +415,11 @@ func (c *Config) GetRefreshTokenLifespan(_ context.Context) time.Duration {
return c.RefreshTokenLifespan
}

// GetDeviceAndUserCodeLifespan returns how long the device and user codes should be valid.
// Defaults to 10 minutes
func (c *Config) GetDeviceAndUserCodeLifespan(_ context.Context) time.Duration {
if c.DeviceAndUserCodeLifespan == 0 {
return time.Minute * 10
return defaultDeviceAndUserCodeLifespan
}
return c.DeviceAndUserCodeLifespan
}
Expand Down Expand Up @@ -523,13 +528,15 @@ func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool {
return c.IsPushedAuthorizeEnforced
}

// GetDeviceVerificationURL returns the device verification URL
func (c *Config) GetDeviceVerificationURL(ctx context.Context) string {
return c.DeviceVerificationURL
}

// GetDeviceAuthTokenPollingInterval returns configured device token endpoint polling interval
func (c *Config) GetDeviceAuthTokenPollingInterval(ctx context.Context) time.Duration {
if c.DeviceAuthTokenPollingInterval == 0 {
return time.Second * 5
return defaultAuthTokenPollingInterval
}
return c.DeviceAuthTokenPollingInterval
}
3 changes: 2 additions & 1 deletion device_request.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2023 Ory Corp
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package fosite
Expand All @@ -8,6 +8,7 @@ type DeviceRequest struct {
Request
}

// NewDeviceRequest returns a new device request
func NewDeviceRequest() *DeviceRequest {
return &DeviceRequest{
Request: *NewRequest(),
Expand Down
33 changes: 8 additions & 25 deletions device_request_handler.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,21 @@
// Copyright © 2023 Ory Corp
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

/*
* Copyright © 2015-2021 Aeneas Rekkas <[email protected]>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* @author Aeneas Rekkas <[email protected]>
* @copyright 2015-2021 Aeneas Rekkas <[email protected]>
* @license Apache-2.0
*
*/

package fosite

import (
"context"
"net/http"
"strings"

"github.com/ory/fosite/i18n"
"github.com/ory/x/errorsx"
"github.com/ory/x/otelx"
"go.opentelemetry.io/otel/trace"

"github.com/ory/fosite/i18n"
)

// NewDeviceRequest parses an http Request returns a Device request
func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ DeviceRequester, err error) {
ctx, span := trace.SpanFromContext(ctx).TracerProvider().Tracer("github.com/ory/fosite").Start(ctx, "Fosite.NewAccessRequest")
defer otelx.End(span, &err)
Expand All @@ -44,9 +25,11 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic

if r.Method != "POST" {
return request, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method))
} else if err := r.ParseForm(); err != nil {
}
if err := r.ParseForm(); err != nil {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error()))
} else if len(r.PostForm) == 0 {
}
if len(r.PostForm) == 0 {
return request, errorsx.WithStack(ErrInvalidRequest.WithHint("The POST body can not be empty."))
}
request.Form = r.PostForm
Expand Down
157 changes: 75 additions & 82 deletions device_request_handler_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2023 Ory Corp
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package fosite_test
Expand Down Expand Up @@ -33,94 +33,87 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) {
expectedError error
mock func()
expect DeviceRequester
}{
/* invalid Method */
{
expectedError: ErrInvalidRequest,
method: "GET",
mock: func() {},
description string
}{{
description: "invalid method",
expectedError: ErrInvalidRequest,
method: "GET",
mock: func() {},
}, {
description: "empty request",
expectedError: ErrInvalidRequest,
method: "POST",
mock: func() {},
}, {
description: "invalid client",
form: url.Values{
"client_id": {"client_id"},
"scope": {"foo bar"},
},
/* empty request */
{
expectedError: ErrInvalidRequest,
method: "POST",
mock: func() {},
expectedError: ErrInvalidClient,
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(nil, errors.New(""))
},
/* invalid client */
{
form: url.Values{
"client_id": {"client_id"},
"scope": {"foo bar"},
},
expectedError: ErrInvalidClient,
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(nil, errors.New(""))
},
}, {
description: "fails because scope not allowed",
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42 foo"},
},
/* fails because scope not allowed */
{
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42 foo"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
},
expectedError: ErrInvalidScope,
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
},
/* fails because scope not allowed */
{
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
"audience": {"aud"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.Audience = []string{"aud2"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
},
expectedError: ErrInvalidRequest,
expectedError: ErrInvalidScope,
}, {
description: "fails because audience not allowed",
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
"audience": {"aud"},
},
/* should fail because doesn't have proper grant */
{
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"authorization_code"}
},
expectedError: ErrInvalidGrant,
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.Audience = []string{"aud2"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
},
/* success case */
{
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
},
expectedError: ErrInvalidRequest,
}, {
description: "fails because it doesn't have the proper grant",
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"authorization_code"}
},
expectedError: ErrInvalidGrant,
}, {
description: "success",
form: url.Values{
"client_id": {"client_id"},
"scope": {"17 42"},
},
method: "POST",
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil)
client.Public = true
client.Scopes = []string{"17", "42"}
client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"}
},
}} {
t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) {
c.mock()
r := &http.Request{
Header: c.header,
Expand Down
2 changes: 1 addition & 1 deletion device_request_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2023 Ory Corp
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package fosite
Expand Down
Loading

0 comments on commit 6322f68

Please sign in to comment.