diff --git a/config/auth.go b/config/auth.go index 7c8586d71..5a58245ca 100644 --- a/config/auth.go +++ b/config/auth.go @@ -1,6 +1,7 @@ package config import ( + "errors" "fmt" "net/url" "strings" @@ -81,6 +82,30 @@ func (c *AuthConfig) RefreshTokenRotationEnabled() bool { } } +// AddOidcProvider adds an OpenID Connect provider to the list of supported authentication providers +func (c *AuthConfig) AddOidcProvider(name string, issuerUrl string, clientId string) error { + if name == "" { + return errors.New("provider name cannot be empty") + } + if invalidUrl(issuerUrl) { + return fmt.Errorf("invalid issuerUrl: %s", issuerUrl) + } + if clientId == "" { + return errors.New("provider clientId cannot be empty") + } + + provider := Provider{ + Type: OpenIdConnectProvider, + Name: name, + IssuerUrl: issuerUrl, + ClientId: clientId, + } + + c.Providers = append(c.Providers, provider) + return nil +} + +// GetOidcProviders returns all OpenID Connect compatible authentication providers func (c *AuthConfig) GetOidcProviders() []Provider { oidcProviders := []Provider{} for _, p := range c.Providers { @@ -113,6 +138,7 @@ func (c *AuthConfig) GetOidcProvidersByIssuer(issuer string) ([]Provider, error) return providers, nil } +// GetIssuer retrieves the issuer URL for the provider func (c *Provider) GetIssuer() (string, error) { switch c.Type { case GoogleProvider: @@ -216,8 +242,7 @@ func findAuthProviderMissingClientId(providers []Provider) []Provider { func findAuthProviderMissingOrInvalidIssuerUrl(providers []Provider) []Provider { invalid := []Provider{} for _, p := range providers { - u, err := url.Parse(p.IssuerUrl) - if err != nil || u.Scheme != "https" { + if invalidUrl(p.IssuerUrl) { invalid = append(invalid, p) continue } @@ -230,8 +255,7 @@ func findAuthProviderMissingOrInvalidIssuerUrl(providers []Provider) []Provider func findAuthProviderMissingOrInvalidTokenUrl(providers []Provider) []Provider { invalid := []Provider{} for _, p := range providers { - u, err := url.Parse(p.TokenUrl) - if err != nil || u.Scheme != "https" { + if invalidUrl(p.TokenUrl) { invalid = append(invalid, p) continue } @@ -244,12 +268,15 @@ func findAuthProviderMissingOrInvalidTokenUrl(providers []Provider) []Provider { func findAuthProviderMissingOrInvalidAuthorizationUrl(providers []Provider) []Provider { invalid := []Provider{} for _, p := range providers { - u, err := url.Parse(p.AuthorizationUrl) - if err != nil || u.Scheme != "https" { + if invalidUrl(p.AuthorizationUrl) { invalid = append(invalid, p) continue } } - return invalid } + +func invalidUrl(u string) bool { + parsed, err := url.Parse(u) + return err != nil || parsed.Scheme != "https" +} diff --git a/config/config_test.go b/config/config_test.go index c5d652774..765019fe8 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -245,3 +245,19 @@ func TestGetOidcSameIssuers(t *testing.T) { assert.NoError(t, err) assert.Len(t, googleIssuer, 3) } + +func TestAddOidcProvider(t *testing.T) { + config, err := Load("fixtures/test_auth.yaml") + assert.NoError(t, err) + + assert.Len(t, config.Auth.GetOidcProviders(), 1) + + err = config.Auth.AddOidcProvider("Custom Auth", "https://mycustomoidc.com", "1234") + assert.NoError(t, err) + + assert.Len(t, config.Auth.GetOidcProviders(), 2) + + byIssuer, err := config.Auth.GetOidcProvidersByIssuer("https://mycustomoidc.com") + assert.NoError(t, err) + assert.Len(t, byIssuer, 1) +} diff --git a/runtime/auth/auth_test.go b/runtime/auth/auth_test.go index 7f2c4a350..12518052d 100644 --- a/runtime/auth/auth_test.go +++ b/runtime/auth/auth_test.go @@ -491,7 +491,7 @@ func TestAllowAllIssuers(t *testing.T) { require.NoError(t, err) } -func TestEnVarFallback(t *testing.T) { +func TestEnvVarFallback(t *testing.T) { ctx := newContext()