Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Override sub with federated_claims.user_id when dex is used #20683

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion cmd/argocd/commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/headless"
"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
argocdclient "github.com/argoproj/argo-cd/v2/pkg/apiclient"
sessionpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/session"
settingspkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/settings"
Expand Down Expand Up @@ -196,7 +197,7 @@ func userDisplayName(claims jwt.MapClaims) string {
if name := jwtutil.StringField(claims, "name"); name != "" {
return name
}
return jwtutil.StringField(claims, "sub")
return utils.GetUserIdentifier(claims)
}

// oauth2Login opens a browser, runs a temporary HTTP server to delegate OAuth2 login flow and
Expand Down
2 changes: 1 addition & 1 deletion cmd/argocd/commands/project_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ Create token succeeded for proj:test-project:test-role.
issuedAt, _ := jwt.IssuedAt(claims)
expiresAt := int64(jwt.Float64Field(claims, "exp"))
id := jwt.StringField(claims, "jti")
subject := jwt.StringField(claims, "sub")
subject := utils.GetUserIdentifier(claims)

if !outputTokenOnly {
fmt.Printf("Create token succeeded for %s.\n", subject)
Expand Down
19 changes: 19 additions & 0 deletions cmd/argocd/commands/utils/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package utils

import (
"github.com/golang-jwt/jwt/v4"
)

// GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use
func GetUserIdentifier(claims jwt.MapClaims) string {
if federatedClaims, ok := claims["federated_claims"].(map[string]interface{}); ok {
if userID, exists := federatedClaims["user_id"].(string); exists && userID != "" {
return userID
}
}
// Fallback to sub
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
return ""
}
1 change: 1 addition & 0 deletions server/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (s *Server) UpdatePassword(ctx context.Context, q *account.UpdatePasswordRe
// check for permission is user is trying to change someone else's password
// assuming user is trying to update someone else if username is different or issuer is not Argo CD
if updatedUsername != username || issuer != session.SessionManagerClaimsIssuer {
log.Printf("Claims: %+v", ctx.Value("claims")) // this line for debug
if err := s.enf.EnforceErr(ctx.Value("claims"), rbacpolicy.ResourceAccounts, rbacpolicy.ActionUpdate, q.Name); err != nil {
return nil, fmt.Errorf("permission denied: %w", err)
}
Expand Down
48 changes: 36 additions & 12 deletions server/account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,54 @@ func getAdminAccount(mgr *settings.SettingsManager) (*settings.Account, error) {
}

func adminContext(ctx context.Context) context.Context {
// nolint:staticcheck
return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{Subject: "admin", Issuer: sessionutil.SessionManagerClaimsIssuer})
claims := jwt.MapClaims{
"sub": "admin",
"iss": sessionutil.SessionManagerClaimsIssuer,
"groups": []string{"role:admin"},
"federated_claims": map[string]interface{}{
"user_id": "admin",
},
}
ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims)
//nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)
return ctx
}

func ssoAdminContext(ctx context.Context, iat time.Time) context.Context {
// nolint:staticcheck
return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{
Subject: "admin",
Issuer: "https://myargocdhost.com/api/dex",
IssuedAt: jwt.NewNumericDate(iat),
})
claims := jwt.MapClaims{
"sub": "admin",
"iss": "https://myargocdhost.com/api/dex",
"iat": jwt.NewNumericDate(iat),
"groups": []string{"role:admin"}, // Add admin group
"federated_claims": map[string]interface{}{
"user_id": "admin",
},
}
// Set both context values
ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims)
//nolint:staticcheck
ctx = context.WithValue(ctx, "claims", claims)

return ctx
}

func projTokenContext(ctx context.Context) context.Context {
claims := jwt.MapClaims{
"sub": "proj:demo:deployer",
"iss": sessionutil.SessionManagerClaimsIssuer,
"groups": []string{"proj:demo:deployer"},
}
ctx = context.WithValue(ctx, sessionutil.ClaimsKey(), claims)
// nolint:staticcheck
return context.WithValue(ctx, "claims", &jwt.RegisteredClaims{
Subject: "proj:demo:deployer",
Issuer: sessionutil.SessionManagerClaimsIssuer,
})
ctx = context.WithValue(ctx, "claims", claims)
return ctx
}

func TestUpdatePassword(t *testing.T) {
accountServer, sessionServer := newTestAccountServer(context.Background())
ctx := adminContext(context.Background())

var err error

// ensure password is not allowed to be updated if given bad password
Expand Down
3 changes: 2 additions & 1 deletion server/rbacpolicy/rbacpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/golang-jwt/jwt/v4"
log "github.com/sirupsen/logrus"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
applister "github.com/argoproj/argo-cd/v2/pkg/client/listers/application/v1alpha1"
jwtutil "github.com/argoproj/argo-cd/v2/util/jwt"
Expand Down Expand Up @@ -114,7 +115,7 @@ func (p *RBACPolicyEnforcer) EnforceClaims(claims jwt.Claims, rvals ...interface
return false
}

subject := jwtutil.StringField(mapClaims, "sub")
subject := utils.GetUserIdentifier(mapClaims)
// Check if the request is for an application resource. We have special enforcement which takes
// into consideration the project's token and group bindings
var runtimePolicy string
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import (
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/pkg/apiclient"
accountpkg "github.com/argoproj/argo-cd/v2/pkg/apiclient/account"
Expand Down Expand Up @@ -1417,7 +1418,7 @@ func (a *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error
log.Errorf("error fetching user info endpoint: %v", err)
return claims, "", status.Errorf(codes.Internal, "invalid userinfo response")
}
if groupClaims["sub"] != userInfo["sub"] {
if utils.GetUserIdentifier(groupClaims) != utils.GetUserIdentifier(userInfo) {
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
}
groupClaims["groups"] = userInfo["groups"]
Expand Down
2 changes: 1 addition & 1 deletion test/container/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ RUN ln -s /usr/lib/$(uname -m)-linux-gnu /usr/lib/linux-gnu
# Please make sure to also check the contained yarn version and update the references below when upgrading this image's version
FROM docker.io/library/node:22.9.0@sha256:69e667a79aa41ec0db50bc452a60e705ca16f35285eaf037ebe627a65a5cdf52 as node

FROM docker.io/library/golang:1.23.3@sha256:d56c3e08fe5b27729ee3834854ae8f7015af48fd651cd25d1e3bcf3c19830174 as golang
FROM docker.io/library/golang:1.23.1@sha256:4f063a24d429510e512cc730c3330292ff49f3ade3ae79bda8f84a24fa25ecb0 as golang

FROM docker.io/library/registry:2.8@sha256:ac0192b549007e22998eb74e8d8488dcfe70f1489520c3b144a6047ac5efbe90 as registry

Expand Down
23 changes: 14 additions & 9 deletions util/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"

"github.com/argoproj/argo-cd/v2/cmd/argocd/commands/utils"
"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/server/settings/oidc"
"github.com/argoproj/argo-cd/v2/util/cache"
Expand Down Expand Up @@ -402,9 +403,8 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
log.Errorf("cannot encrypt accessToken: %v (claims=%s)", err, claimsJSON)
return
}
sub := jwtutil.StringField(claims, "sub")
err = a.clientCache.Set(&cache.Item{
Key: formatAccessTokenCacheKey(sub),
Key: formatAccessTokenCacheKey(claims),
Object: encToken,
CacheActionOpts: cache.CacheActionOpts{
Expiration: getTokenExpiration(claims),
Expand Down Expand Up @@ -552,12 +552,12 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc

// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
sub := jwtutil.StringField(actualClaims, "sub")
sub := utils.GetUserIdentifier(actualClaims)
var claims jwt.MapClaims
var encClaims []byte

// in case we got it in the cache, we just return the item
clientCacheKey := formatUserInfoResponseCacheKey(sub)
clientCacheKey := formatUserInfoResponseCacheKey(actualClaims)
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
if err != nil {
Expand All @@ -575,7 +575,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP

// check if the accessToken for the user is still present
var encAccessToken []byte
err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken)
err := a.clientCache.Get(formatAccessTokenCacheKey(actualClaims), &encAccessToken)
// without an accessToken we can't query the user info endpoint
// thus the user needs to reauthenticate for argocd to get a new accessToken
if errors.Is(err, cache.ErrCacheMiss) {
Expand Down Expand Up @@ -607,6 +607,9 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
if response.StatusCode == http.StatusUnauthorized {
return claims, true, err
}
if response.StatusCode == http.StatusNotFound {
return jwt.MapClaims{}, true, fmt.Errorf("user info path not found: %s", userInfoPath)
}

// according to https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponseValidation
// the response should be validated
Expand Down Expand Up @@ -684,11 +687,13 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
}

// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
func formatUserInfoResponseCacheKey(sub string) string {
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
func formatUserInfoResponseCacheKey(claims jwt.MapClaims) string {
userID := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, userID)
}

// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
func formatAccessTokenCacheKey(sub string) string {
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
func formatAccessTokenCacheKey(claims jwt.MapClaims) string {
userID := utils.GetUserIdentifier(claims)
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, userID)
}
Loading
Loading