diff --git a/cmd/argocd/commands/login.go b/cmd/argocd/commands/login.go index 72b89dae1771c..d8e31f4fd679b 100644 --- a/cmd/argocd/commands/login.go +++ b/cmd/argocd/commands/login.go @@ -20,6 +20,7 @@ import ( "golang.org/x/oauth2" "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/headless" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" argocdclient "github.com/argoproj/argo-cd/v2/pkg/apiclient" sessionpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/session" settingspkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings" @@ -196,7 +197,7 @@ func userDisplayName(claims jwt.MapClaims) string { if name := jwtutil.StringField(claims, "name"); name != "" { return name } - return jwtutil.StringField(claims, "sub") + return utils.GetUserIdentifier(claims) } // oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and diff --git a/cmd/argocd/commands/project_role.go b/cmd/argocd/commands/project_role.go index a0da6793fa7e6..2bfe74739375d 100644 --- a/cmd/argocd/commands/project_role.go +++ b/cmd/argocd/commands/project_role.go @@ -332,7 +332,7 @@ Create token succeeded for proj:test-project:test-role. issuedAt, _ := jwt.IssuedAt(claims) expiresAt := int64(jwt.Float64Field(claims, "exp")) id := jwt.StringField(claims, "jti") - subject := jwt.StringField(claims, "sub") + subject := utils.GetUserIdentifier(claims) if !outputTokenOnly { fmt.Printf("Create token succeeded for %s.\n", subject) diff --git a/cmd/argocd/commands/utils/claims.go b/cmd/argocd/commands/utils/claims.go new file mode 100644 index 0000000000000..f901ab50e0b8e --- /dev/null +++ b/cmd/argocd/commands/utils/claims.go @@ -0,0 +1,19 @@ +package utils + +import ( + "github.com/golang-jwt/jwt/v4" +) + +// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use +func GetUserIdentifier(claims jwt.MapClaims) string { + if federatedClaims, ok := claims["federated_claims"].(map[string]interface{}); ok { + if userID, exists := federatedClaims["user_id"].(string); exists && userID != "" { + return userID + } + } + // Fallback to sub + if sub, ok := claims["sub"].(string); ok && sub != "" { + return sub + } + return "" +} diff --git a/server/account/account.go b/server/account/account.go index 541401a731022..0e377329f4fe2 100644 --- a/server/account/account.go +++ b/server/account/account.go @@ -47,6 +47,7 @@ func (s *Server) UpdatePassword(ctx context.Context, q *account.UpdatePasswordRe // check for permission is user is trying to change someone else's password // assuming user is trying to update someone else if username is different or issuer is not Argo CD if updatedUsername != username || issuer != session.SessionManagerClaimsIssuer { + log.Printf("Claims: %+v", ctx.Value("claims")) // this line for debug if err := s.enf.EnforceErr(ctx.Value("claims"), rbacpolicy.ResourceAccounts, rbacpolicy.ActionUpdate, q.Name); err != nil { return nil, fmt.Errorf("permission denied: %w", err) } diff --git a/server/account/account_test.go b/server/account/account_test.go index 2e7f9ab669e9d..ffcba4209d7a0 100644 --- a/server/account/account_test.go +++ b/server/account/account_test.go @@ -82,30 +82,54 @@ func getAdminAccount(mgr *settings.SettingsManager) (*settings.Account, error) { } func adminContext(ctx context.Context) context.Context { - // nolint:staticcheck - return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{Subject: "admin", Issuer: sessionutil.SessionManagerClaimsIssuer}) + claims := jwt.MapClaims{ + "sub": "admin", + "iss": sessionutil.SessionManagerClaimsIssuer, + "groups": []string{"role:admin"}, + "federated_claims": map[string]interface{}{ + "user_id": "admin", + }, + } + ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims) + //nolint:staticcheck + ctx = context.WithValue(ctx, "claims", claims) + return ctx } func ssoAdminContext(ctx context.Context, iat time.Time) context.Context { - // nolint:staticcheck - return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{ - Subject: "admin", - Issuer: "https://myargocdhost.com/api/dex", - IssuedAt: jwt.NewNumericDate(iat), - }) + claims := jwt.MapClaims{ + "sub": "admin", + "iss": "https://myargocdhost.com/api/dex", + "iat": jwt.NewNumericDate(iat), + "groups": []string{"role:admin"}, // Add admin group + "federated_claims": map[string]interface{}{ + "user_id": "admin", + }, + } + // Set both context values + ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims) + //nolint:staticcheck + ctx = context.WithValue(ctx, "claims", claims) + + return ctx } func projTokenContext(ctx context.Context) context.Context { + claims := jwt.MapClaims{ + "sub": "proj:demo:deployer", + "iss": sessionutil.SessionManagerClaimsIssuer, + "groups": []string{"proj:demo:deployer"}, + } + ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims) // nolint:staticcheck - return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{ - Subject: "proj:demo:deployer", - Issuer: sessionutil.SessionManagerClaimsIssuer, - }) + ctx = context.WithValue(ctx, "claims", claims) + return ctx } func TestUpdatePassword(t *testing.T) { accountServer, sessionServer := newTestAccountServer(context.Background()) ctx := adminContext(context.Background()) + var err error // ensure password is not allowed to be updated if given bad password diff --git a/server/rbacpolicy/rbacpolicy.go b/server/rbacpolicy/rbacpolicy.go index 800dcd43c064a..326f294298743 100644 --- a/server/rbacpolicy/rbacpolicy.go +++ b/server/rbacpolicy/rbacpolicy.go @@ -6,6 +6,7 @@ import ( "github.com/golang-jwt/jwt/v4" log "github.com/sirupsen/logrus" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1" applister "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1" jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" @@ -114,7 +115,7 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...interface return false } - subject := jwtutil.StringField(mapClaims, "sub") + subject := utils.GetUserIdentifier(mapClaims) // Check if the request is for an application resource. We have special enforcement which takes // into consideration the project's token and group bindings var runtimePolicy string diff --git a/server/server.go b/server/server.go index 6625461dfab03..b4be459194e28 100644 --- a/server/server.go +++ b/server/server.go @@ -59,6 +59,7 @@ import ( "k8s.io/client-go/tools/cache" "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/pkg/apiclient" accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account" @@ -1417,7 +1418,7 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error log.Errorf("error fetching user info endpoint: %v", err) return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") } - if groupClaims["sub"] != userInfo["sub"] { + if utils.GetUserIdentifier(groupClaims) != utils.GetUserIdentifier(userInfo) { return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") } groupClaims["groups"] = userInfo["groups"] diff --git a/test/container/Dockerfile b/test/container/Dockerfile index cb01e50e0952b..8909b9f49bf53 100644 --- a/test/container/Dockerfile +++ b/test/container/Dockerfile @@ -8,7 +8,7 @@ RUN ln -s /usr/lib/$(uname -m)-linux-gnu /usr/lib/linux-gnu # Please make sure to also check the contained yarn version and update the references below when upgrading this image's version FROM docker.io/library/node:22.9.0@sha256:69e667a79aa41ec0db50bc452a60e705ca16f35285eaf037ebe627a65a5cdf52 as node -FROM docker.io/library/golang:1.23.3@sha256:d56c3e08fe5b27729ee3834854ae8f7015af48fd651cd25d1e3bcf3c19830174 as golang +FROM docker.io/library/golang:1.23.1@sha256:4f063a24d429510e512cc730c3330292ff49f3ade3ae79bda8f84a24fa25ecb0 as golang FROM docker.io/library/registry:2.8@sha256:ac0192b549007e22998eb74e8d8488dcfe70f1489520c3b144a6047ac5efbe90 as registry diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 2f01dc167e3d4..f1f2e4bbc82ee 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -21,6 +21,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/oauth2" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/server/settings/oidc" "github.com/argoproj/argo-cd/v2/util/cache" @@ -402,9 +403,8 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON) return } - sub := jwtutil.StringField(claims, "sub") err = a.clientCache.Set(&cache.Item{ - Key: formatAccessTokenCacheKey(sub), + Key: formatAccessTokenCacheKey(claims), Object: encToken, CacheActionOpts: cache.CacheActionOpts{ Expiration: getTokenExpiration(claims), @@ -552,12 +552,12 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc // GetUserInfo queries the IDP userinfo endpoint for claims func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) { - sub := jwtutil.StringField(actualClaims, "sub") + sub := utils.GetUserIdentifier(actualClaims) var claims jwt.MapClaims var encClaims []byte // in case we got it in the cache, we just return the item - clientCacheKey := formatUserInfoResponseCacheKey(sub) + clientCacheKey := formatUserInfoResponseCacheKey(actualClaims) if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil { claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey) if err != nil { @@ -575,7 +575,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP // check if the accessToken for the user is still present var encAccessToken []byte - err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken) + err := a.clientCache.Get(formatAccessTokenCacheKey(actualClaims), &encAccessToken) // without an accessToken we can't query the user info endpoint // thus the user needs to reauthenticate for argocd to get a new accessToken if errors.Is(err, cache.ErrCacheMiss) { @@ -607,6 +607,9 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP if response.StatusCode == http.StatusUnauthorized { return claims, true, err } + if response.StatusCode == http.StatusNotFound { + return jwt.MapClaims{}, true, fmt.Errorf("user info path not found: %s", userInfoPath) + } // according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation // the response should be validated @@ -684,11 +687,13 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration { } // formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache -func formatUserInfoResponseCacheKey(sub string) string { - return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub) +func formatUserInfoResponseCacheKey(claims jwt.MapClaims) string { + userID := utils.GetUserIdentifier(claims) + return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, userID) } // formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache -func formatAccessTokenCacheKey(sub string) string { - return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub) +func formatAccessTokenCacheKey(claims jwt.MapClaims) string { + userID := utils.GetUserIdentifier(claims) + return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, userID) } diff --git a/util/oidc/oidc_test.go b/util/oidc/oidc_test.go index 40c606dcd9671..2aebddaf7a706 100644 --- a/util/oidc/oidc_test.go +++ b/util/oidc/oidc_test.go @@ -629,9 +629,9 @@ func TestGetUserInfo(t *testing.T) { { name: "call UserInfo with wrong userInfoPath", userInfoPath: "/user", - expectedOutput: jwt.MapClaims(nil), + expectedOutput: jwt.MapClaims{}, expectError: true, - expectUnauthenticated: false, + expectUnauthenticated: true, expectedCacheItems: []struct { key string value string @@ -639,11 +639,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) }, @@ -654,7 +654,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), value: "FakeAccessToken", encrypt: true, }, @@ -673,11 +673,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "fallbackUser"}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "fallbackUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) }, @@ -688,7 +688,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "fallbackUser"}), value: "FakeAccessToken", encrypt: true, }, @@ -707,11 +707,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, r *http.Request) { userInfoBytes := ` notevenJsongarbage @@ -730,7 +730,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), value: "FakeAccessToken", encrypt: true, }, @@ -749,11 +749,11 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), expectError: true, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, r *http.Request) { userInfoBytes := ` { @@ -782,13 +782,13 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), value: "{\"groups\":[\"githubOrg:engineers\"]}", expectEncrypted: true, expectError: false, }, }, - idpClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + idpClaims: jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, idpHandler: func(w http.ResponseWriter, r *http.Request) { userInfoBytes := ` { @@ -809,7 +809,177 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "randomUser", "federated_claims": map[string]interface{}{"user_id": "randomUser"}}), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with different sub and federated_claims", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{ + "sub": "different-sub", + "federated_claims": map[string]interface{}{ + "connector_id": "github", + "user_id": "preferred-id", + }, + "groups": []interface{}{"githubOrg:engineers"}, + }, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + // Key should use federated_claims.user_id (preferred-id) instead of sub + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "different-sub", "federated_claims": map[string]interface{}{"user_id": "preferred-id"}}), + value: `{"sub":"different-sub","federated_claims":{"connector_id":"github","user_id":"preferred-id"},"groups":["githubOrg:engineers"]}`, + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{ + "sub": "different-sub", + "federated_claims": map[string]interface{}{ + "connector_id": "github", + "user_id": "preferred-id", + }, + "exp": float64(time.Now().Add(5 * time.Minute).Unix()), + }, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + response := jwt.MapClaims{ + "sub": "different-sub", + "federated_claims": map[string]interface{}{ + "connector_id": "github", + "user_id": "preferred-id", + }, + "groups": []interface{}{"githubOrg:engineers"}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + t.Errorf("failed to encode response: %v", err) + } + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + // Access token cache key should also use federated_claims.user_id + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "different-sub", "federated_claims": map[string]interface{}{"user_id": "preferred-id"}}), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with only sub claim", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{"sub": "sub-only-user", "groups": []interface{}{"githubOrg:engineers"}}, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"sub": "sub-only-user"}), + value: `{"sub":"sub-only-user","groups":["githubOrg:engineers"]}`, + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{ + "sub": "sub-only-user", + "exp": float64(time.Now().Add(5 * time.Minute).Unix()), + }, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + response := jwt.MapClaims{ + "sub": "sub-only-user", + "groups": []interface{}{"githubOrg:engineers"}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + t.Errorf("failed to encode response: %v", err) + } + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(jwt.MapClaims{"sub": "sub-only-user"}), + value: "FakeAccessToken", + encrypt: true, + }, + }, + }, + { + name: "call UserInfo with only federated claims", + userInfoPath: "/user-info", + expectedOutput: jwt.MapClaims{ + "federated_claims": map[string]interface{}{ + "connector_id": "github", + "user_id": "federated-only-user", + }, + "groups": []interface{}{"githubOrg:engineers"}, + }, + expectError: false, + expectUnauthenticated: false, + expectedCacheItems: []struct { + key string + value string + expectEncrypted bool + expectError bool + }{ + { + key: formatUserInfoResponseCacheKey(jwt.MapClaims{"federated_claims": map[string]interface{}{"user_id": "federated-only-user"}}), + value: `{"federated_claims":{"connector_id":"github","user_id":"federated-only-user"},"groups":["githubOrg:engineers"]}`, + expectEncrypted: true, + expectError: false, + }, + }, + idpClaims: jwt.MapClaims{ + "federated_claims": map[string]interface{}{ + "connector_id": "github", + "user_id": "federated-only-user", + }, + "exp": float64(time.Now().Add(5 * time.Minute).Unix()), + }, + idpHandler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + response := jwt.MapClaims{ + "federated_claims": map[string]interface{}{ + "connector_id": "github", + "user_id": "federated-only-user", + }, + "groups": []interface{}{"githubOrg:engineers"}, + } + if err := json.NewEncoder(w).Encode(response); err != nil { + t.Errorf("failed to encode response: %v", err) + } + }, + cache: cache.NewInMemoryCache(24 * time.Hour), + cacheItems: []struct { + key string + value string + encrypt bool + }{ + { + key: formatAccessTokenCacheKey(jwt.MapClaims{"federated_claims": map[string]interface{}{"user_id": "federated-only-user"}}), value: "FakeAccessToken", encrypt: true, }, @@ -848,6 +1018,9 @@ func TestGetUserInfo(t *testing.T) { assert.Equal(t, tt.expectUnauthenticated, unauthenticated) if tt.expectError { require.Error(t, err) + if tt.userInfoPath != "/user-info" { + assert.Contains(t, err.Error(), "user info path not found") + } } else { require.NoError(t, err) } @@ -862,7 +1035,13 @@ func TestGetUserInfo(t *testing.T) { tmpValue, err = crypto.Decrypt(tmpValue, encryptionKey) require.NoError(t, err) } - assert.Equal(t, item.value, string(tmpValue)) + // Compare as objects instead of strings + var expected, actual map[string]interface{} + err = json.Unmarshal([]byte(item.value), &expected) + require.NoError(t, err) + err = json.Unmarshal(tmpValue, &actual) + require.NoError(t, err) + assert.Equal(t, expected, actual) } } }) diff --git a/util/rbac/rbac.go b/util/rbac/rbac.go index ab45cf5c0d69f..d7c24b5b8138f 100644 --- a/util/rbac/rbac.go +++ b/util/rbac/rbac.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/util/assets" "github.com/argoproj/argo-cd/v2/util/glob" jwtutil "github.com/argoproj/argo-cd/v2/util/jwt" @@ -255,7 +256,7 @@ func (e *Enforcer) EnforceErr(rvals ...interface{}) error { if err != nil { break } - if sub := jwtutil.StringField(claims, "sub"); sub != "" { + if sub := utils.GetUserIdentifier(claims); sub != "" { rvalsStrs = append(rvalsStrs, fmt.Sprintf("sub: %s", sub)) } if issuedAtTime, err := jwtutil.IssuedAtTime(claims); err == nil { diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index 09ba6aa43cd38..e79f25fd249ff 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -20,6 +20,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1" "github.com/argoproj/argo-cd/v2/server/rbacpolicy" @@ -159,16 +160,19 @@ func (mgr *SessionManager) Create(subject string, secondsBeforeExpiry int64, id // Create a new token object, specifying signing method and the claims // you would like it to contain. now := time.Now().UTC() - claims := jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(now), - Issuer: SessionManagerClaimsIssuer, - NotBefore: jwt.NewNumericDate(now), - Subject: subject, - ID: id, + claims := jwt.MapClaims{ + "iat": now.Unix(), + "iss": SessionManagerClaimsIssuer, + "nbf": now.Unix(), + "sub": subject, + "jti": id, + "federated_claims": map[string]interface{}{ + "user_id": "", // Empty for local auth + }, } if secondsBeforeExpiry > 0 { expires := now.Add(time.Duration(secondsBeforeExpiry) * time.Second) - claims.ExpiresAt = jwt.NewNumericDate(expires) + claims["exp"] = expires.Unix() } return mgr.signClaims(claims) @@ -226,7 +230,7 @@ func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, string, error) return nil, "", err } - subject := jwtutil.StringField(claims, "sub") + subject := utils.GetUserIdentifier(claims) id := jwtutil.StringField(claims, "jti") if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok { @@ -502,9 +506,17 @@ func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) h return } ctx := r.Context() + + // Assert that claims is of type jwt.MapClaims + mapClaims, ok := claims.(jwt.MapClaims) + if !ok { + http.Error(w, "Invalid claims type", http.StatusUnauthorized) + return + } + // Add claims to the context to inspect for RBAC // nolint:staticcheck - ctx = context.WithValue(ctx, "claims", claims) + ctx = context.WithValue(ctx, "user_id", utils.GetUserIdentifier(mapClaims)) r = r.WithContext(ctx) } next.ServeHTTP(w, r) @@ -593,12 +605,7 @@ func Username(ctx context.Context) string { if !ok { return "" } - switch jwtutil.StringField(mapClaims, "iss") { - case SessionManagerClaimsIssuer: - return jwtutil.StringField(mapClaims, "sub") - default: - return jwtutil.StringField(mapClaims, "email") - } + return utils.GetUserIdentifier(mapClaims) } func Iss(ctx context.Context) string { @@ -622,7 +629,7 @@ func Sub(ctx context.Context) string { if !ok { return "" } - return jwtutil.StringField(mapClaims, "sub") + return utils.GetUserIdentifier(mapClaims) } func Groups(ctx context.Context, scopes []string) []string { @@ -633,8 +640,17 @@ func Groups(ctx context.Context, scopes []string) []string { return jwtutil.GetGroups(mapClaims, scopes) } +type contextKey struct{} + +var claimsKey = contextKey{} + +// ClaimsKey returns the context key used for claims +func ClaimsKey() interface{} { + return claimsKey +} + func mapClaims(ctx context.Context) (jwt.MapClaims, bool) { - claims, ok := ctx.Value("claims").(jwt.Claims) + claims, ok := ctx.Value(claimsKey).(jwt.Claims) if !ok { return nil, false } diff --git a/util/session/sessionmanager_test.go b/util/session/sessionmanager_test.go index 1ef496706feb9..30345436ae5ee 100644 --- a/util/session/sessionmanager_test.go +++ b/util/session/sessionmanager_test.go @@ -24,6 +24,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/kubernetes/fake" + "github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils" "github.com/argoproj/argo-cd/v2/common" appv1 "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1" apps "github.com/argoproj/argo-cd/v2/pkg/client/clientset/versioned/fake" @@ -99,7 +100,7 @@ func TestSessionManager_AdminToken(t *testing.T) { assert.Empty(t, newToken) mapClaims := *(claims.(*jwt.MapClaims)) - subject := mapClaims["sub"].(string) + subject := utils.GetUserIdentifier(mapClaims) if subject != "admin" { t.Errorf("Token claim subject \"%s\" does not match expected subject \"%s\".", subject, "admin") } @@ -126,7 +127,7 @@ func TestSessionManager_AdminToken_ExpiringSoon(t *testing.T) { claims, _, err := mgr.Parse(newToken) require.NoError(t, err) mapClaims := *(claims.(*jwt.MapClaims)) - subject := mapClaims["sub"].(string) + subject := utils.GetUserIdentifier(mapClaims) assert.Equal(t, "admin", subject) } @@ -234,10 +235,17 @@ type tokenVerifierMock struct { } func (tm *tokenVerifierMock) VerifyToken(token string) (jwt.Claims, string, error) { - if tm.claims == nil { + if tm.err != nil { return nil, "", tm.err } - return tm.claims, "", tm.err + mapClaims := jwt.MapClaims{ + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + } + if tm.claims == nil { + return jwt.MapClaims{}, "", nil + } + return mapClaims, "", nil } func strPointer(str string) *string { @@ -346,29 +354,36 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) { } } -var loggedOutContext = context.Background() - -// nolint:staticcheck -var loggedInContext = context.WithValue(context.Background(), "claims", &jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "bar", "groups": []string{"baz"}}) +var ( + loggedOutContext = context.Background() + // nolint:staticcheck + loggedInContext = context.WithValue(context.Background(), claimsKey, &jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "bar", "groups": []string{"baz"}, "federated_claims": map[string]interface{}{"user_id": "foo"}}) + // for testing without federated claims + loggedInContextNoFederated = context.WithValue(context.Background(), claimsKey, &jwt.MapClaims{"iss": "qux", "sub": "foo", "email": "bar", "groups": []string{"baz"}}) +) func TestIss(t *testing.T) { assert.Empty(t, Iss(loggedOutContext)) assert.Equal(t, "qux", Iss(loggedInContext)) + assert.Equal(t, "foo", Sub(loggedInContextNoFederated)) // Without federated claims, falls back to sub } func TestLoggedIn(t *testing.T) { assert.False(t, LoggedIn(loggedOutContext)) assert.True(t, LoggedIn(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContextNoFederated)) } func TestUsername(t *testing.T) { assert.Empty(t, Username(loggedOutContext)) - assert.Equal(t, "bar", Username(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContextNoFederated)) } func TestSub(t *testing.T) { assert.Empty(t, Sub(loggedOutContext)) assert.Equal(t, "foo", Sub(loggedInContext)) + assert.Equal(t, "foo", Username(loggedInContextNoFederated)) } func TestGroups(t *testing.T) {