Skip to content

Commit

Permalink
Add oauth2_metadata config option (#320)
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Dykstra <[email protected]>
  • Loading branch information
DrDaveD authored Aug 14, 2024
1 parent 69b88e6 commit 23e8687
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 1 deletion.
21 changes: 21 additions & 0 deletions path_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,24 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
}
}

// Also fetch any requested extra oauth2 metadata
oauth2Metadata := make(map[string]string)
for _, mdname := range role.Oauth2Metadata {
var md string
switch mdname {
case "id_token":
md = string(token.IDToken())
case "refresh_token":
md = string(token.RefreshToken())
case "access_token":
md = string(token.AccessToken())
default:
// previously validated so this should never happen
return logical.ErrorResponse(errLoginFailed + " Unrecognized oauth2 metadata name " + mdname), nil
}
oauth2Metadata[mdname] = md
}

if role.VerboseOIDCLogging {
if c, err := json.Marshal(allClaims); err == nil {
b.Logger().Debug("OIDC provider response", "claims", string(c))
Expand All @@ -388,6 +406,9 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request,
for k, v := range alias.Metadata {
tokenMetadata[k] = v
}
for k, v := range oauth2Metadata {
tokenMetadata["oauth2_"+k] = v
}

auth := &logical.Auth{
Policies: role.Policies,
Expand Down
11 changes: 10 additions & 1 deletion path_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,8 +846,16 @@ func TestOIDC_Callback(t *testing.T) {

auth := resp.Auth

if auth != nil {
// Can't predict the content of oauth2_id_token
// so instead copy it. This does at least
// verify that it is present because if not it
// introduces an empty value into expected.
expected.Metadata["oauth2_id_token"] = auth.Metadata["oauth2_id_token"]
}

if !reflect.DeepEqual(auth, expected) {
t.Fatalf("expected: %v, auth: %v", expected, resp)
t.Fatalf("expected: %v, resp: %v", expected, resp)
}
}
})
Expand Down Expand Up @@ -1617,6 +1625,7 @@ func getBackendAndServer(t *testing.T, boundCIDRs bool, callbackMode string) (lo
"/nested/secret_code": "bar",
"temperature": "76",
},
"oauth2_metadata": []string{"id_token"},
}

if boundCIDRs {
Expand Down
20 changes: 20 additions & 0 deletions path_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -148,6 +149,10 @@ Defaults to 60 (1 minute) if set to 0 and can be disabled if set to -1.`,
Type: framework.TypeKVPairs,
Description: `Mappings of claims (key) that will be copied to a metadata field (value)`,
},
"oauth2_metadata": {
Type: framework.TypeCommaStringSlice,
Description: `Comma-separated list of one or more of access_token, id_token, refresh_token to return in metadata`,
},
"user_claim": {
Type: framework.TypeString,
Description: `The claim to use for the Identity entity alias name`,
Expand Down Expand Up @@ -238,6 +243,7 @@ type jwtRole struct {
BoundClaimsType string `json:"bound_claims_type"`
BoundClaims map[string]interface{} `json:"bound_claims"`
ClaimMappings map[string]string `json:"claim_mappings"`
Oauth2Metadata []string `json:"oauth2_metadata"`
UserClaim string `json:"user_claim"`
GroupsClaim string `json:"groups_claim"`
OIDCScopes []string `json:"oidc_scopes"`
Expand Down Expand Up @@ -354,6 +360,7 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request,
"bound_claims_type": role.BoundClaimsType,
"bound_claims": role.BoundClaims,
"claim_mappings": role.ClaimMappings,
"oauth2_metadata": role.Oauth2Metadata,
"user_claim": role.UserClaim,
"user_claim_json_pointer": role.UserClaimJSONPointer,
"groups_claim": role.GroupsClaim,
Expand Down Expand Up @@ -561,6 +568,19 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical.
role.ClaimMappings = claimMappings
}

if oauth2Metadata, ok := data.GetOk("oauth2_metadata"); ok {
role.Oauth2Metadata = oauth2Metadata.([]string)
for _, mdname := range role.Oauth2Metadata {
if !slices.Contains([]string{
"id_token",
"refresh_token",
"access_token",
}, mdname) {
return logical.ErrorResponse("Unrecognized oauth2 metadata name " + mdname), nil
}
}
}

if userClaim, ok := data.GetOk("user_claim"); ok {
role.UserClaim = userClaim.(string)
}
Expand Down
1 change: 1 addition & 0 deletions path_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ func TestPath_Read(t *testing.T) {
"bound_claims_type": "string",
"bound_claims": map[string]interface{}(nil),
"claim_mappings": map[string]string(nil),
"oauth2_metadata": []string(nil),
"bound_subject": "testsub",
"bound_audiences": []string{"vault"},
"allowed_redirect_uris": []string{"http://127.0.0.1"},
Expand Down

0 comments on commit 23e8687

Please sign in to comment.