From 665a5b627c5f257c821a09bf08209795be7dfe51 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Wed, 10 Jan 2024 20:03:37 +0100 Subject: [PATCH] Override OIDC provider discovered claims (#3267) Signed-off-by: m.nabokikh --- connector/oidc/oidc.go | 70 ++++++++++++++++++++++++++++++++++--- connector/oidc/oidc_test.go | 51 +++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 4 deletions(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 21129f2227..b125979b99 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -27,6 +27,10 @@ type Config struct { ClientSecret string `json:"clientSecret"` RedirectURI string `json:"redirectURI"` + // The section to override options discovered automatically from + // the providers' discovery URL (.well-known/openid-configuration). + ProviderDiscoveryOverrides ProviderDiscoveryOverrides `json:"providerDiscoveryOverrides"` + // Causes client_secret to be passed as POST parameters instead of basic // auth. This is specifically "NOT RECOMMENDED" by the OAuth2 RFC, but some // providers require it. @@ -96,6 +100,61 @@ type Config struct { } `json:"claimModifications"` } +type ProviderDiscoveryOverrides struct { + // TokenURL provides a way to user overwrite the Token URL + // from the .well-known/openid-configuration token_endpoint + TokenURL string `json:"tokenURL"` + // AuthURL provides a way to user overwrite the Auth URL + // from the .well-known/openid-configuration authorization_endpoint + AuthURL string `json:"authURL"` +} + +func (o *ProviderDiscoveryOverrides) Empty() bool { + return o.TokenURL == "" && o.AuthURL == "" +} + +func getProvider(ctx context.Context, issuer string, overrides ProviderDiscoveryOverrides) (*oidc.Provider, error) { + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, fmt.Errorf("failed to get provider: %v", err) + } + + if overrides.Empty() { + return provider, nil + } + + v := &struct { + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + DeviceAuthURL string `json:"device_authorization_endpoint"` + JWKSURL string `json:"jwks_uri"` + UserInfoURL string `json:"userinfo_endpoint"` + Algorithms []string `json:"id_token_signing_alg_values_supported"` + }{} + if err := provider.Claims(v); err != nil { + return nil, fmt.Errorf("failed to extract provider discovery claims: %v", err) + } + config := oidc.ProviderConfig{ + IssuerURL: v.Issuer, + AuthURL: v.AuthURL, + TokenURL: v.TokenURL, + DeviceAuthURL: v.DeviceAuthURL, + JWKSURL: v.JWKSURL, + UserInfoURL: v.UserInfoURL, + Algorithms: v.Algorithms, + } + + if overrides.TokenURL != "" { + config.TokenURL = overrides.TokenURL + } + if overrides.AuthURL != "" { + config.AuthURL = overrides.AuthURL + } + + return config.NewProvider(context.Background()), nil +} + // NewGroupFromClaims creates a new group from a list of claims and appends it to the list of existing groups. type NewGroupFromClaims struct { // List of claim to join together @@ -152,13 +211,16 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e return nil, err } - ctx, cancel := context.WithCancel(context.Background()) - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + bgctx, cancel := context.WithCancel(context.Background()) + ctx := context.WithValue(bgctx, oauth2.HTTPClient, httpClient) - provider, err := oidc.NewProvider(ctx, c.Issuer) + provider, err := getProvider(ctx, c.Issuer, c.ProviderDiscoveryOverrides) if err != nil { cancel() - return nil, fmt.Errorf("failed to get provider: %v", err) + return nil, err + } + if !c.ProviderDiscoveryOverrides.Empty() { + logger.Warnf("overrides for connector %q are set, this can be a vulnerability when not properly configured", id) } endpoint := provider.Endpoint() diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 4bb84a40d6..950d158338 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -584,6 +584,57 @@ func TestTokenIdentity(t *testing.T) { } } +func TestProviderOverride(t *testing.T) { + testServer, err := setupServer(map[string]any{ + "sub": "subvalue", + "name": "namevalue", + }, true) + if err != nil { + t.Fatal("failed to setup test server", err) + } + + t.Run("No override", func(t *testing.T) { + conn, err := newConnector(Config{ + Issuer: testServer.URL, + Scopes: []string{"openid", "groups"}, + }) + if err != nil { + t.Fatal("failed to create new connector", err) + } + + expAuth := fmt.Sprintf("%s/authorize", testServer.URL) + if conn.provider.Endpoint().AuthURL != expAuth { + t.Fatalf("unexpected auth URL: %s, expected: %s\n", conn.provider.Endpoint().AuthURL, expAuth) + } + + expToken := fmt.Sprintf("%s/token", testServer.URL) + if conn.provider.Endpoint().TokenURL != expToken { + t.Fatalf("unexpected token URL: %s, expected: %s\n", conn.provider.Endpoint().TokenURL, expToken) + } + }) + + t.Run("Override", func(t *testing.T) { + conn, err := newConnector(Config{ + Issuer: testServer.URL, + Scopes: []string{"openid", "groups"}, + ProviderDiscoveryOverrides: ProviderDiscoveryOverrides{TokenURL: "/test1", AuthURL: "/test2"}, + }) + if err != nil { + t.Fatal("failed to create new connector", err) + } + + expAuth := "/test2" + if conn.provider.Endpoint().AuthURL != expAuth { + t.Fatalf("unexpected auth URL: %s, expected: %s\n", conn.provider.Endpoint().AuthURL, expAuth) + } + + expToken := "/test1" + if conn.provider.Endpoint().TokenURL != expToken { + t.Fatalf("unexpected token URL: %s, expected: %s\n", conn.provider.Endpoint().TokenURL, expToken) + } + }) +} + func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) { key, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil {