Skip to content

Commit

Permalink
feat: keelconfig auth configuration (#1296)
Browse files Browse the repository at this point in the history
  • Loading branch information
davenewza authored Nov 13, 2023
1 parent 7fc95de commit e75a27c
Show file tree
Hide file tree
Showing 21 changed files with 1,035 additions and 106 deletions.
1 change: 1 addition & 0 deletions cmd/program/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {

ctx = db.WithDatabase(ctx, m.Database)
ctx = runtimectx.WithSecrets(ctx, m.Secrets)
ctx = runtimectx.WithOAuthConfig(ctx, &m.Config.Auth)

mailClient := mail.NewSMTPClientFromEnv()
if mailClient != nil {
Expand Down
205 changes: 205 additions & 0 deletions config/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package config

import (
"fmt"
"net/url"

"github.com/samber/lo"
)

const (
GoogleProvider = "google"
OpenIdConnectProvider = "oidc"
OAuthProvider = "oauth"
)

var (
SupportedProviderTypes = []string{
GoogleProvider,
OpenIdConnectProvider,
OAuthProvider,
}
)

type AuthConfig struct {
Tokens *TokensConfig `yaml:"tokens"`
Providers []Provider `yaml:"providers"`
}

type TokensConfig struct {
AccessTokenExpiry int `yaml:"accessTokenExpiry"`
RefreshTokenExpiry int `yaml:"refreshTokenExpiry"`
}

type Provider struct {
Type string `yaml:"type"`
Name string `yaml:"name"`
ClientId string `yaml:"clientId"`
IssuerUrl string `yaml:"issuerUrl"`
TokenUrl string `yaml:"tokenUrl"`
AuthorizationUrl string `yaml:"authorizationUrl"`
}

func (c *AuthConfig) GetOidcProviders() []Provider {
oidcProviders := []Provider{}
for _, p := range c.Providers {
if p.Type == OpenIdConnectProvider {
oidcProviders = append(oidcProviders, p)
}
}
return oidcProviders
}

func (c *AuthConfig) HasOidcIssuer(issuer string) (bool, error) {
for _, p := range c.Providers {
if p.Type == OAuthProvider {
continue
}

issuerUrl, err := p.GetIssuer()
if err != nil {
return false, err
}
if issuerUrl == issuer {
return true, nil
}
}
return false, nil
}

func (c *Provider) GetIssuer() (string, error) {
switch c.Type {
case GoogleProvider:
return "https://accounts.google.com/", nil
case OpenIdConnectProvider:
return c.IssuerUrl, nil
default:
return "", fmt.Errorf("the provider type '%s' should not have an issuer url configured", c.Type)
}
}

func (c *AuthConfig) GetOAuthProviders() []Provider {
oidcProviders := []Provider{}
for _, p := range c.Providers {
if p.Type == OAuthProvider {
oidcProviders = append(oidcProviders, p)
}
}
return oidcProviders
}

func (c *Provider) GetTokenUrl() (string, error) {
switch c.Type {
case GoogleProvider:
return "https://accounts.google.com/o/oauth2/token", nil
case OAuthProvider:
return c.TokenUrl, nil
default:
return "", fmt.Errorf("the provider type '%s' should not have a token url configured", c.Type)
}
}

func (c *Provider) GetAuthorizationUrl() (string, error) {
switch c.Type {
case GoogleProvider:
return "https://accounts.google.com/o/oauth2/auth", nil
case OAuthProvider:
return c.AuthorizationUrl, nil
default:
return "", fmt.Errorf("the provider type '%s' should not have a token url configured", c.Type)
}
}

// findAuthProviderMissingName checks for missing provider names
func findAuthProviderMissingName(providers []Provider) []Provider {
invalid := []Provider{}
for _, p := range providers {
if p.Name == "" {
invalid = append(invalid, p)
}
}

return invalid
}

// findAuthProviderDuplicateName checks for duplicate auth provider names
func findAuthProviderDuplicateName(providers []Provider) []Provider {
keys := make(map[string]bool)

duplicates := []Provider{}
for _, p := range providers {
if _, value := keys[p.Name]; !value {
keys[p.Name] = true
} else {
duplicates = append(duplicates, p)
}
}

return duplicates
}

// findAuthProviderInvalidType checks for invalid provider types
func findAuthProviderInvalidType(providers []Provider) []Provider {
invalid := []Provider{}
for _, p := range providers {
if !lo.Contains(SupportedProviderTypes, p.Type) {
invalid = append(invalid, p)
}
}

return invalid
}

// findAuthProviderMissingClientId checks for missing client IDs
func findAuthProviderMissingClientId(providers []Provider) []Provider {
invalid := []Provider{}
for _, p := range providers {
if p.ClientId == "" {
invalid = append(invalid, p)
}
}

return invalid
}

// findAuthProviderMissingIssuerUrl checks for missing or invalid issuer URLs
func findAuthProviderMissingOrInvalidIssuerUrl(providers []Provider) []Provider {
invalid := []Provider{}
for _, p := range providers {
u, err := url.Parse(p.IssuerUrl)
if err != nil || u.Scheme != "https" {
invalid = append(invalid, p)
continue
}
}

return invalid
}

// findAuthProviderMissingOrInvalidTokenUrl checks for missing or invalid token URLs
func findAuthProviderMissingOrInvalidTokenUrl(providers []Provider) []Provider {
invalid := []Provider{}
for _, p := range providers {
u, err := url.Parse(p.TokenUrl)
if err != nil || u.Scheme != "https" {
invalid = append(invalid, p)
continue
}
}

return invalid
}

// findAuthProviderMissingOrInvalidAuthorizationUrl checks for missing or invalid authorization URLs
func findAuthProviderMissingOrInvalidAuthorizationUrl(providers []Provider) []Provider {
invalid := []Provider{}
for _, p := range providers {
u, err := url.Parse(p.AuthorizationUrl)
if err != nil || u.Scheme != "https" {
invalid = append(invalid, p)
continue
}
}

return invalid
}
103 changes: 99 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Config struct{}
type ProjectConfig struct {
Environment EnvironmentConfig `yaml:"environment"`
Secrets []Input `yaml:"secrets"`
Auth AuthConfig `yaml:"auth"`
DisableAuth bool `yaml:"disableKeelAuth"`
}

Expand Down Expand Up @@ -117,10 +118,16 @@ type ConfigError struct {
}

const (
ConfigDuplicateErrorString = "environment variable %s has a duplicate set in environment: %s"
ConfigRequiredErrorString = "environment variable %s is required but not defined in the following environments: %s"
ConfigIncorrectNamingErrorString = "%s must be written in upper snakecase"
ConfigReservedNameErrorString = "environment variable %s cannot start with %s as it is reserved"
ConfigDuplicateErrorString = "environment variable %s has a duplicate set in environment: %s"
ConfigRequiredErrorString = "environment variable %s is required but not defined in the following environments: %s"
ConfigIncorrectNamingErrorString = "%s must be written in upper snakecase"
ConfigReservedNameErrorString = "environment variable %s cannot start with %s as it is reserved"
ConfigAuthTokenExpiryMustBePositive = "%s token lifespan cannot be negative or zero for field: %s"
ConfigAuthProviderMissingFieldAtIndexErrorString = "auth provider at index %v is missing field: %s"
ConfigAuthProviderMissingFieldErrorString = "auth provider '%s' is missing field: %s"
ConfigAuthProviderInvalidTypeErrorString = "auth provider '%s' has invalid type '%s' which must be one of: %s"
ConfigAuthProviderDuplicateErrorString = "auth provider name '%s' has been defined more than once, but must be unique"
ConfigAuthProviderInvalidHttpUrlErrorString = "auth provider '%s' has missing or invalid https url for field: %s"
)

type ConfigErrors struct {
Expand Down Expand Up @@ -232,6 +239,94 @@ func Validate(config *ProjectConfig) *ConfigErrors {
}
}

if config.Auth.Tokens != nil && config.Auth.Tokens.AccessTokenExpiry <= 0 {
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "access", "accessTokenExpiry"),
})
}

if config.Auth.Tokens != nil && config.Auth.Tokens.RefreshTokenExpiry <= 0 {
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthTokenExpiryMustBePositive, "refresh", "refreshTokenExpiry"),
})
}

missingProviderNames := findAuthProviderMissingName(config.Auth.Providers)
for i := range missingProviderNames {
errors = append(errors, &ConfigError{
Type: "missing",
Message: fmt.Sprintf(ConfigAuthProviderMissingFieldAtIndexErrorString, i, "name"),
})
}

invalidProviderTypes := findAuthProviderInvalidType(config.Auth.Providers)
for _, p := range invalidProviderTypes {
if p.Name == "" {
continue
}
errors = append(errors, &ConfigError{
Type: "missing",
Message: fmt.Sprintf(ConfigAuthProviderInvalidTypeErrorString, p.Name, p.Type, strings.Join(SupportedProviderTypes, ", ")),
})
}

duplicateProviders := findAuthProviderDuplicateName(config.Auth.Providers)
for _, p := range duplicateProviders {
if p.Name == "" {
continue
}
errors = append(errors, &ConfigError{
Type: "duplicate",
Message: fmt.Sprintf(ConfigAuthProviderDuplicateErrorString, p.Name),
})
}

missingClientIds := findAuthProviderMissingClientId(config.Auth.Providers)
for _, p := range missingClientIds {
if p.Name == "" {
continue
}
errors = append(errors, &ConfigError{
Type: "missing",
Message: fmt.Sprintf(ConfigAuthProviderMissingFieldErrorString, p.Name, "clientId"),
})
}

missingOrInvalidIssuerUrls := findAuthProviderMissingOrInvalidIssuerUrl(config.Auth.GetOidcProviders())
for _, p := range missingOrInvalidIssuerUrls {
if p.Name == "" {
continue
}
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthProviderInvalidHttpUrlErrorString, p.Name, "issuerUrl"),
})
}

missingOrInvalidTokenUrls := findAuthProviderMissingOrInvalidTokenUrl(config.Auth.GetOAuthProviders())
for _, p := range missingOrInvalidTokenUrls {
if p.Name == "" {
continue
}
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthProviderInvalidHttpUrlErrorString, p.Name, "tokenUrl"),
})
}

missingOrInvalidAuthUrls := findAuthProviderMissingOrInvalidAuthorizationUrl(config.Auth.GetOAuthProviders())
for _, p := range missingOrInvalidAuthUrls {
if p.Name == "" {
continue
}
errors = append(errors, &ConfigError{
Type: "invalid",
Message: fmt.Sprintf(ConfigAuthProviderInvalidHttpUrlErrorString, p.Name, "authorizationUrl"),
})
}

if len(errors) == 0 {
return nil
}
Expand Down
Loading

0 comments on commit e75a27c

Please sign in to comment.