Skip to content

Commit

Permalink
feat: refresh token rotation configuration & github and facebook prov…
Browse files Browse the repository at this point in the history
…iders (#1298)
  • Loading branch information
davenewza authored Nov 14, 2023
1 parent 269ff40 commit 4611221
Show file tree
Hide file tree
Showing 14 changed files with 624 additions and 255 deletions.
63 changes: 54 additions & 9 deletions config/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,46 @@ package config
import (
"fmt"
"net/url"
"strings"
"time"

"github.com/samber/lo"
)

const (
// 24 hours is the default access token expiry period
DefaultAccessTokenExpiry time.Duration = time.Hour * 24
// 3 months is the default refresh token expiry period
DefaultRefreshTokenExpiry time.Duration = time.Hour * 24 * 90
)

const (
GoogleProvider = "google"
FacebookProvider = "facebook"
GitLabProvider = "gitlab"
OpenIdConnectProvider = "oidc"
OAuthProvider = "oauth"
)

var (
SupportedProviderTypes = []string{
GoogleProvider,
FacebookProvider,
GitLabProvider,
OpenIdConnectProvider,
OAuthProvider,
}
)

type AuthConfig struct {
Tokens *TokensConfig `yaml:"tokens"`
Providers []Provider `yaml:"providers"`
Tokens TokensConfig `yaml:"tokens"`
Providers []Provider `yaml:"providers"`
}

type TokensConfig struct {
AccessTokenExpiry int `yaml:"accessTokenExpiry"`
RefreshTokenExpiry int `yaml:"refreshTokenExpiry"`
AccessTokenExpiry *int `yaml:"accessTokenExpiry,omitempty"`
RefreshTokenExpiry *int `yaml:"refreshTokenExpiry,omitempty"`
RefreshTokenRotationEnabled *bool `yaml:"refreshTokenRotationEnabled,omitempty"`
}

type Provider struct {
Expand All @@ -40,6 +54,33 @@ type Provider struct {
AuthorizationUrl string `yaml:"authorizationUrl"`
}

// AccessTokenExpiry retrieves the configured or default access token expiry
func (c *AuthConfig) AccessTokenExpiry() time.Duration {
if c.Tokens.AccessTokenExpiry != nil {
return time.Duration(*c.Tokens.AccessTokenExpiry) * time.Second
} else {
return DefaultAccessTokenExpiry
}
}

// RefreshTokenExpiry retrieves the configured or default refresh token expiry
func (c *AuthConfig) RefreshTokenExpiry() time.Duration {
if c.Tokens.RefreshTokenExpiry != nil {
return time.Duration(*c.Tokens.RefreshTokenExpiry) * time.Second
} else {
return DefaultRefreshTokenExpiry
}
}

// RefreshTokenRotationEnabled retrieves the configured or default refresh token rotation
func (c *AuthConfig) RefreshTokenRotationEnabled() bool {
if c.Tokens.RefreshTokenRotationEnabled != nil {
return *c.Tokens.RefreshTokenRotationEnabled
} else {
return true
}
}

func (c *AuthConfig) GetOidcProviders() []Provider {
oidcProviders := []Provider{}
for _, p := range c.Providers {
Expand All @@ -50,9 +91,9 @@ func (c *AuthConfig) GetOidcProviders() []Provider {
return oidcProviders
}

// GetProvidersOidcIssuer gets all providers by issuer url.
// It's possible that multiple providers from the same issuer as configured.
func (c *AuthConfig) GetProvidersOidcIssuer(issuer string) ([]Provider, error) {
// GetOidcProvidersByIssuer gets all OpenID Connect providers by issuer url.
// It's possible that multiple providers from the same issuer are configured.
func (c *AuthConfig) GetOidcProvidersByIssuer(issuer string) ([]Provider, error) {
providers := []Provider{}

for _, p := range c.Providers {
Expand All @@ -64,7 +105,7 @@ func (c *AuthConfig) GetProvidersOidcIssuer(issuer string) ([]Provider, error) {
if err != nil {
return nil, err
}
if issuerUrl == issuer {
if strings.TrimSuffix(issuerUrl, "/") == strings.TrimSuffix(issuer, "/") {
providers = append(providers, p)
}
}
Expand All @@ -75,7 +116,11 @@ func (c *AuthConfig) GetProvidersOidcIssuer(issuer string) ([]Provider, error) {
func (c *Provider) GetIssuer() (string, error) {
switch c.Type {
case GoogleProvider:
return "https://accounts.google.com/", nil
return "https://accounts.google.com", nil
case FacebookProvider:
return "https://www.facebook.com", nil
case GitLabProvider:
return "https://gitlab.com", nil
case OpenIdConnectProvider:
return c.IssuerUrl, nil
default:
Expand Down
5 changes: 2 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,21 @@ func Validate(config *ProjectConfig) *ConfigErrors {
if hasIncorrectNames {
for incorrectName := range incorrectNames {
startsWith := reservedEnvVarRegex.FindString(incorrectName)

errors = append(errors, &ConfigError{
Type: "reserved",
Message: fmt.Sprintf(ConfigReservedNameErrorString, incorrectName, startsWith),
})
}
}

if config.Auth.Tokens != nil && config.Auth.Tokens.AccessTokenExpiry < 0 {
if config.Auth.AccessTokenExpiry() <= 0 {
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "access", "accessTokenExpiry"),
})
}

if config.Auth.Tokens != nil && config.Auth.Tokens.RefreshTokenExpiry < 0 {
if config.Auth.RefreshTokenExpiry() <= 0 {
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "refresh", "refreshTokenExpiry"),
Expand Down
37 changes: 28 additions & 9 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -119,8 +120,26 @@ func TestAuthTokens(t *testing.T) {
config, err := Load("fixtures/test_auth.yaml")
assert.NoError(t, err)

assert.Equal(t, 3600, config.Auth.Tokens.AccessTokenExpiry)
assert.Equal(t, 604800, config.Auth.Tokens.RefreshTokenExpiry)
assert.Equal(t, 3600, *config.Auth.Tokens.AccessTokenExpiry)
assert.Equal(t, 604800, *config.Auth.Tokens.RefreshTokenExpiry)
assert.Equal(t, false, *config.Auth.Tokens.RefreshTokenRotationEnabled)

assert.Equal(t, time.Duration(3600)*time.Second, config.Auth.AccessTokenExpiry())
assert.Equal(t, time.Duration(604800)*time.Second, config.Auth.RefreshTokenExpiry())
assert.Equal(t, false, config.Auth.RefreshTokenRotationEnabled())
}

func TestAuthDefaults(t *testing.T) {
config, err := Load("fixtures/test_auth_empty.yaml")
assert.NoError(t, err)

assert.Nil(t, config.Auth.Tokens.AccessTokenExpiry)
assert.Nil(t, config.Auth.Tokens.RefreshTokenExpiry)
assert.Nil(t, config.Auth.Tokens.RefreshTokenRotationEnabled)

assert.Equal(t, time.Duration(24)*time.Hour, config.Auth.AccessTokenExpiry())
assert.Equal(t, time.Duration(24)*time.Hour*90, config.Auth.RefreshTokenExpiry())
assert.Equal(t, true, config.Auth.RefreshTokenRotationEnabled())
}

func TestAuthNegativeTokenLifespan(t *testing.T) {
Expand Down Expand Up @@ -171,9 +190,9 @@ func TestDuplicateProviderName(t *testing.T) {
func TestInvalidProviderTypes(t *testing.T) {
_, err := Load("fixtures/test_auth_invalid_types.yaml")

assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'google_1' has invalid type 'google_1' which must be one of: google, facebook, gitlab, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'google_2' has invalid type 'Google' which must be one of: google, facebook, gitlab, oidc, oauth\n")
assert.Contains(t, err.Error(), "auth provider 'Baidu' has invalid type 'whoops' which must be one of: google, facebook, gitlab, oidc, oauth\n")
}

func TestMissingClientId(t *testing.T) {
Expand Down Expand Up @@ -205,15 +224,15 @@ func TestGetOidcIssuer(t *testing.T) {
config, err := Load("fixtures/test_auth.yaml")
assert.NoError(t, err)

googleIssuer, err := config.Auth.GetProvidersOidcIssuer("https://accounts.google.com/")
googleIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://accounts.google.com/")
assert.NoError(t, err)
assert.Len(t, googleIssuer, 2)

auth0Issuer, err := config.Auth.GetProvidersOidcIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com")
auth0Issuer, err := config.Auth.GetOidcProvidersByIssuer("https://dev-skhlutl45lbqkvhv.us.auth0.com")
assert.NoError(t, err)
assert.Len(t, auth0Issuer, 1)

nopeIssuer, err := config.Auth.GetProvidersOidcIssuer("https://nope.com")
nopeIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://nope.com")
assert.NoError(t, err)
assert.Len(t, nopeIssuer, 0)
}
Expand All @@ -222,7 +241,7 @@ func TestGetOidcSameIssuers(t *testing.T) {
config, err := Load("fixtures/test_auth_same_issuers.yaml")
assert.NoError(t, err)

googleIssuer, err := config.Auth.GetProvidersOidcIssuer("https://accounts.google.com/")
googleIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://accounts.google.com/")
assert.NoError(t, err)
assert.Len(t, googleIssuer, 3)
}
3 changes: 2 additions & 1 deletion config/fixtures/test_auth.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ auth:
tokens:
accessTokenExpiry: 3600
refreshTokenExpiry: 604800

refreshTokenRotationEnabled: false

providers:
# Built-in Google provider
- type: google
Expand Down
Empty file.
45 changes: 43 additions & 2 deletions runtime/apis/authapi/revoke_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,53 @@ import (

"github.com/teamkeel/keel/proto"
"github.com/teamkeel/keel/runtime/common"
"github.com/teamkeel/keel/runtime/oauth"
"go.opentelemetry.io/otel/trace"
)

type RevokeEndpointErrorResponse struct {
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
}

func RevokeHandler(schema *proto.Schema) common.HandlerFunc {
return func(r *http.Request) common.Response {
return common.Response{
Status: http.StatusNotImplemented,
ctx, span := tracer.Start(r.Context(), "Revoke Token")
defer span.End()

if r.Method != http.MethodPost {
return common.NewJsonResponse(http.StatusMethodNotAllowed, &ErrorResponse{
Error: InvalidRequest,
ErrorDescription: "the revoke endpoint only accepts POST",
}, nil)
}

if !HasContentType(r.Header, "application/x-www-form-urlencoded") {
return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{
Error: InvalidRequest,
ErrorDescription: "the request must be an encoded form with Content-Type application/x-www-form-urlencoded",
}, nil)
}

refreshTokenRaw := r.FormValue(ArgToken)

if refreshTokenRaw == "" {
return common.NewJsonResponse(http.StatusBadRequest, &ErrorResponse{
Error: InvalidRequest,
ErrorDescription: "the refresh token must be provided in the token field",
}, nil)
}

// Revoke the refresh token
err := oauth.RevokeRefreshToken(ctx, refreshTokenRaw)
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
return common.NewJsonResponse(http.StatusUnauthorized, &ErrorResponse{
Error: InvalidClient,
ErrorDescription: "possible causes may be that the id token is invalid, has expired, or has insufficient claims",
}, nil)
}

return common.NewJsonResponse(http.StatusOK, nil, nil)
}
}
Loading

0 comments on commit 4611221

Please sign in to comment.