-
Notifications
You must be signed in to change notification settings - Fork 1
/
claims.go
201 lines (177 loc) · 5.86 KB
/
claims.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package jwtauth
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/mitchellh/pointerstructure"
"github.com/ryanuber/go-glob"
)
// setClaim stores val in allClaims at the location named by claim.
// If claim starts with "/", it is interpreted as a JSON pointer to locate the
// claim. Otherwise, the claim string is used directly as a top-level key.
//
// It returns val on success. If the JSON pointer cannot be applied, a warning
// is logged and nil is returned.
func setClaim(logger log.Logger, allClaims map[string]interface{}, claim string, val interface{}) interface{} {
	if !strings.HasPrefix(claim, "/") {
		allClaims[claim] = val
		return val
	}

	// pointerstructure.Set returns the updated *document*, not the value that
	// was written. Returning that document here would make the pointer branch
	// return something different from the plain-key branch above. The document
	// was never reassigned by this function, so allClaims is relied upon to be
	// updated in place (NOTE(review): relies on pointerstructure mutating
	// map-rooted documents in place — confirm).
	if _, err := pointerstructure.Set(allClaims, claim, val); err != nil {
		logger.Warn(fmt.Sprintf("unable to set %s in claims: %s", claim, err.Error()))
		return nil
	}
	return val
}
// getClaim returns a claim value from allClaims given a provided claim string.
// If this string starts with "/", it is interpreted as a JSON pointer to locate
// the claim. Otherwise, the claim string is used directly as a top-level key.
//
// Two normalizations are applied to the located value:
//   - a string-valued "aud" claim is wrapped in a single-element []interface{};
//   - a top-level float32/float64 value is converted to json.Number
//     (floats nested inside lists are left unchanged).
//
// It returns nil when the claim is absent, or when the JSON pointer cannot be
// resolved (the latter is logged as a warning).
func getClaim(logger log.Logger, allClaims map[string]interface{}, claim string) interface{} {
	var val interface{}
	var err error

	if !strings.HasPrefix(claim, "/") {
		val = allClaims[claim]
	} else {
		val, err = pointerstructure.Get(allClaims, claim)
		if err != nil {
			logger.Warn(fmt.Sprintf("unable to locate %s in claims: %s", claim, err.Error()))
			return nil
		}
	}

	if claim == "aud" && val != nil {
		// Per RFC 7519 Section 4.1.3:
		//
		// > In the special case when the JWT has one audience, the "aud"
		// > value MAY be a single case-sensitive string containing a
		// > StringOrURI value.
		//
		// Because other code expects audience to be a slice, update our
		// copy if we only have a single value.
		if singleVal, ok := val.(string); ok {
			val = []interface{}{singleVal}
		}
	}

	// The claims unmarshalled by go-oidc don't use UseNumber, so there will
	// be mismatches if they're coming in as float64 since Vault's config will
	// be represented as json.Number. If the operator can coerce claims data to
	// be in string form, there is no problem. As an alternative, we convert
	// float32 and float64 to json.Number.
	var f float64
	var bitSize int
	switch v := val.(type) {
	case float32:
		f, bitSize = float64(v), 32
	case float64:
		f, bitSize = v, 64
	default:
		return val
	}

	// Collapse negative zero so it still renders as "0".
	if f == 0 {
		f = 0
	}

	// Format with the shortest exact representation. Integral values render as
	// before (e.g. 42 -> "42"). Fractional values are no longer truncated (which
	// could make 1.9 match a bound value of 1), and large magnitudes no longer
	// overflow int.
	return json.Number(strconv.FormatFloat(f, 'f', -1, bitSize))
}
// extractMetadata turns claims into a metadata map using claimMappings, which
// maps a claim name or JSON pointer (as understood by getClaim) to the metadata
// key the claim's value should be stored under, e.g.:
//
//	{
//		"/some/claim/pointer": "metadata_key1",
//		"another_claim":       "metadata_key2",
//		...
//	}
//
// Claims that resolve to nil are skipped. Any resolved claim that is not a
// string aborts extraction with an error and a nil map.
func extractMetadata(logger log.Logger, allClaims map[string]interface{}, claimMappings map[string]string) (map[string]string, error) {
	result := make(map[string]string)

	for claimName, metadataKey := range claimMappings {
		raw := getClaim(logger, allClaims, claimName)
		if raw == nil {
			continue
		}

		s, isString := raw.(string)
		if !isString {
			return nil, fmt.Errorf("error converting claim '%s' to string", claimName)
		}
		result[metadataKey] = s
	}

	return result, nil
}
// validateAudience reports an error unless at least one entry of
// boundAudiences appears in audClaim.
//
// When boundAudiences is empty, audiences are not checked. The exception is
// strict mode, where a non-empty audClaim is rejected because the role does not
// bind any audience.
func validateAudience(boundAudiences, audClaim []string, strict bool) error {
	if len(boundAudiences) == 0 {
		if strict && len(audClaim) > 0 {
			return errors.New("audience claim found in JWT but no audiences bound to the role")
		}
		return nil
	}

	for _, bound := range boundAudiences {
		if strutil.StrListContains(audClaim, bound) {
			return nil
		}
	}
	return errors.New("aud claim does not match any bound audience")
}
// validateBoundClaims verifies every entry in boundClaims against allClaims.
//
// For each bound claim name, the received value is looked up with getClaim.
// Both the received and the expected value are normalized to lists with
// normalizeList. Validation passes for that claim if matchFound reports any
// overlap. Glob matching is used when boundClaimsType equals
// boundClaimsTypeGlob.
//
// The first missing, malformed, or non-matching claim produces an error.
func validateBoundClaims(logger log.Logger, boundClaimsType string, boundClaims, allClaims map[string]interface{}) error {
	globbing := boundClaimsType == boundClaimsTypeGlob

	for name, expected := range boundClaims {
		received := getClaim(logger, allClaims, name)
		if received == nil {
			return fmt.Errorf("claim %q is missing", name)
		}

		receivedList, isList := normalizeList(received)
		if !isList {
			return fmt.Errorf("received claim is not a string, bool, int or list: %v", received)
		}

		expectedList, isList := normalizeList(expected)
		if !isList {
			return fmt.Errorf("bound claim is not a string, bool, int or list: %v", expected)
		}

		matched, err := matchFound(expectedList, receivedList, globbing)
		switch {
		case err != nil:
			return err
		case !matched:
			return fmt.Errorf("claim %q does not match any associated bound claim values", name)
		}
	}

	return nil
}
// matchFound reports whether any value in actVals matches any value in expVals.
//
// When useGlobs is true, each expected value is treated as a glob pattern
// (glob.Glob) and matched against string actual values. Non-string actual
// values are skipped. A non-string expected value is an error; this is only
// detected when there is at least one actual value to compare it against.
// When useGlobs is false, values are compared with ==.
func matchFound(expVals, actVals []interface{}, useGlobs bool) (bool, error) {
	for _, expVal := range expVals {
		for _, actVal := range actVals {
			if useGlobs {
				// Only string globbing is supported.
				expValStr, ok := expVal.(string)
				if !ok {
					// %v (not the former "%expVal", which parsed as the %e verb
					// followed by literal text) so the offending value is printed.
					return false, fmt.Errorf("received claim is not a glob string: %v", expVal)
				}
				actValStr, ok := actVal.(string)
				if !ok {
					continue
				}
				if !glob.Glob(expValStr, actValStr) {
					continue
				}
			} else {
				// NOTE(review): == on two interface values panics at run time if
				// both hold the same non-comparable dynamic type (e.g. nested
				// []interface{} or map values). This assumes such values never
				// reach here — confirm against role/bound-claim validation.
				if actVal != expVal {
					continue
				}
			}
			return true, nil
		}
	}
	return false, nil
}
// normalizeList coerces a claim value into a []interface{}.
//
// A []interface{} is returned as-is. A scalar string, bool, or json.Number is
// wrapped in a one-element list. This accommodates providers that collapse a
// single-item list into a bare scalar. Any other type yields (nil, false).
func normalizeList(raw interface{}) ([]interface{}, bool) {
	if list, isList := raw.([]interface{}); isList {
		return list, true
	}

	switch raw.(type) {
	case string, bool, json.Number:
		return []interface{}{raw}, true
	}

	return nil, false
}