forked from cloudtrust/common-service
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauthentication.go
371 lines (309 loc) · 13.6 KB
/
authentication.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
package middleware
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"regexp"
	"strings"

	cs "github.com/cloudtrust/common-service/v2"
	errorhandler "github.com/cloudtrust/common-service/v2/errors"
	"github.com/cloudtrust/common-service/v2/log"
	"github.com/cloudtrust/common-service/v2/security"
	"github.com/golang-jwt/jwt/v5"
	errorsPkg "github.com/pkg/errors"
)
// splitIssuer splits a Keycloak issuer URL into its domain part and its realm name.
// Both layouts are supported: "<domain>/auth/realms/<realm>" (tried first) and
// "<domain>/realms/<realm>". Everything after the first separator is the realm.
// If the issuer contains neither separator, the whole issuer is returned as the
// domain and the realm is empty, instead of panicking with an index out of range
// (the issuer comes from an untrusted token, so this case must not crash).
func splitIssuer(issuer string) (string, string) {
	if domain, realm, found := strings.Cut(issuer, "/auth/realms/"); found {
		return domain, realm
	}
	if domain, realm, found := strings.Cut(issuer, "/realms/"); found {
		return domain, realm
	}
	return issuer, ""
}
// MakeHTTPBasicAuthenticationFuncMW returns a middleware that reads the
// 'Authorization: Basic <token>' HTTP header and checks the token with the
// given callback.
//
// credsMatcher receives the raw (still base64-encoded) token and returns:
//   - a non-nil username when the credentials are accepted; the username is then
//     stored in the request context under cs.CtContextUsername before calling next,
//   - nil and a nil error when the credentials are rejected (401 Unauthorized),
//   - an error when the token cannot be checked (403 Forbidden).
//
// A missing Authorization header, or one that does not use the Basic scheme,
// is answered with 403 Forbidden.
func MakeHTTPBasicAuthenticationFuncMW(credsMatcher func(token string) (*string, error), logger log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			// Use the request context rather than context.TODO() so that logging and
			// error handling see the request-scoped values and cancellation.
			var ctx = req.Context()

			var token, err = extractBasicAuthentication(ctx, req.Header.Get("Authorization"), logger)
			if err != nil {
				httpErrorHandler(ctx, http.StatusForbidden, err, w)
				return
			}

			authenticated, err := credsMatcher(token)
			if err != nil {
				httpErrorHandler(ctx, http.StatusForbidden, err, w)
				return
			}
			if authenticated == nil {
				logger.Info(ctx, "msg", "Authorization error: Invalid password value")
				httpErrorHandler(ctx, http.StatusUnauthorized, errors.New(errorhandler.MsgErrInvalidParam+"."+errorhandler.Token), w)
				return
			}

			ctx = context.WithValue(ctx, cs.CtContextUsername, *authenticated)
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}
}
// MakeHTTPBasicAuthenticationMapMW returns a Basic-authentication middleware that
// accepts exactly the user/password pairs of credentials (username -> password).
// The expected base64 Basic tokens are computed once, when the middleware is
// built. An incoming token is accepted only if it is identical to one of them,
// in which case the matching username is stored in the request context.
// Requests without the header are refused (see MakeHTTPBasicAuthenticationFuncMW
// for the status codes).
func MakeHTTPBasicAuthenticationMapMW(credentials map[string]string, logger log.Logger) func(http.Handler) http.Handler {
	// Lookup table: base64("username:password") -> username.
	usernameByToken := make(map[string]string, len(credentials))
	for username, secret := range credentials {
		plain := fmt.Sprintf("%s:%s", username, secret)
		usernameByToken[base64.StdEncoding.EncodeToString([]byte(plain))] = username
	}

	matcher := func(token string) (*string, error) {
		username, known := usernameByToken[token]
		if !known {
			return nil, nil
		}
		return &username, nil
	}
	return MakeHTTPBasicAuthenticationFuncMW(matcher, logger)
}
// MakeHTTPBasicAuthenticationMW returns a Basic-authentication middleware that
// accepts any username as long as the password equals passwordToMatch.
// If there is no Authorization header, the request is not allowed.
// On success, the username taken from the token is added into the request
// context (cs.CtContextUsername). A token that is not valid base64 or has no
// "username:password" form is reported as an error to the underlying
// MakeHTTPBasicAuthenticationFuncMW.
func MakeHTTPBasicAuthenticationMW(passwordToMatch string, logger log.Logger) func(http.Handler) http.Handler {
	var expectedPassword = []byte(passwordToMatch)
	return MakeHTTPBasicAuthenticationFuncMW(func(token string) (*string, error) {
		var ctx = context.TODO()
		var username, password, err = decodeBasicAuthToken(ctx, token, logger)
		if err != nil {
			return nil, err
		}
		// Constant-time comparison: a plain == may stop at the first differing byte,
		// which could let an attacker recover the secret from response timings.
		// Only the password length can still leak.
		if subtle.ConstantTimeCompare([]byte(password), expectedPassword) == 1 {
			return &username, nil
		}
		return nil, nil
	}, logger)
}
// basicAuthHeaderRegexp matches a "Basic <token>" Authorization header value and
// captures the token. It is compiled once here rather than on every request.
var basicAuthHeaderRegexp = regexp.MustCompile(`^[Bb]asic (.+)$`)

// extractBasicAuthentication returns the token part of a "Basic <token>"
// Authorization header value, without decoding it.
// It logs and returns an error when the header is empty or does not use the
// Basic scheme.
func extractBasicAuthentication(ctx context.Context, authorizationHeader string, logger log.Logger) (string, error) {
	if authorizationHeader == "" {
		logger.Info(ctx, "msg", "Authorization error: Missing Authorization header")
		return "", errors.New(errorhandler.MsgErrMissingParam + "." + errorhandler.AuthHeader)
	}

	var match = basicAuthHeaderRegexp.FindStringSubmatch(authorizationHeader)
	if match == nil {
		logger.Info(ctx, "msg", "Authorization error: Missing basic token")
		return "", errors.New(errorhandler.MsgErrMissingParam + "." + errorhandler.BasicToken)
	}
	// match[1] is the first captured group: the token itself.
	return match[1], nil
}
// decodeBasicAuthToken decodes a base64 Basic authentication token and splits it
// into its username and password parts.
// The split happens at the first ':' only. Per RFC 7617 the user-id cannot
// contain a colon but the password may, so "user:pa:ss" yields ("user", "pa:ss").
// An error is logged and returned when the token is not valid standard base64
// or contains no ':'.
func decodeBasicAuthToken(ctx context.Context, authToken string, logger log.Logger) (string, string, error) {
	decodedToken, err := base64.StdEncoding.DecodeString(authToken)
	if err != nil {
		logger.Info(ctx, "msg", "Authorization error: Invalid base64 token")
		return "", "", errors.New(errorhandler.MsgErrInvalidParam + "." + errorhandler.Token)
	}

	username, password, found := strings.Cut(string(decodedToken), ":")
	if !found {
		logger.Info(ctx, "msg", "Authorization error: Invalid token format (username:password)")
		return "", "", errors.New(errorhandler.MsgErrInvalidParam + "." + errorhandler.Token)
	}
	return username, password, nil
}
// KeycloakClient is the interface of the keycloak client used by the OIDC
// middleware to check access tokens.
type KeycloakClient interface {
	// VerifyToken checks accessToken against the given realm.
	// In this file it is called with the domain and realm obtained by splitting
	// the token's "iss" claim; a non-nil error is treated as an authentication failure.
	VerifyToken(issuer string, realmName string, accessToken string) error
}
// bearerAuthHeaderRegexp matches a "Bearer <token>" Authorization header value and
// captures the access token. It is compiled once here rather than on every request.
var bearerAuthHeaderRegexp = regexp.MustCompile(`^[Bb]earer +([^ ]+)$`)

// MakeHTTPOIDCTokenValidationMW retrieves the OIDC token from the HTTP
// 'Authorization: Bearer <token>' header and checks its validity with the Keycloak
// instance bound to the component (see ParseAndValidateOIDCToken). The token must
// list audienceRequired among its audiences.
// If there is no such header, the request is refused with 403 Forbidden.
// An invalid token is refused with 403 Forbidden, or with 401 Unauthorized when the
// Keycloak verification fails.
// If the token is valid, the following information is added into the request context:
//   - cs.CtContextAccessToken: the received access token in raw format
//   - cs.CtContextRealm: realm name extracted from the issuer of the token
//   - cs.CtContextUserID: subject of the token
//   - cs.CtContextUsername: username extracted from the token
//   - cs.CtContextGroups: groups of the token, without their leading '/'
//   - cs.CtContextIssuerDomain: domain part of the issuer of the token
func MakeHTTPOIDCTokenValidationMW(keycloakClient KeycloakClient, audienceRequired string, logger log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			// Use the request context rather than context.TODO() so that logging and
			// error handling see the request-scoped values and cancellation.
			var ctx = req.Context()

			var authorizationHeader = req.Header.Get("Authorization")
			if authorizationHeader == "" {
				logger.Info(ctx, "msg", "Authorization error: Missing Authorization header")
				httpErrorHandler(ctx, http.StatusForbidden, errors.New(errorhandler.MsgErrMissingParam+"."+errorhandler.AuthHeader), w)
				return
			}

			var match = bearerAuthHeaderRegexp.FindStringSubmatch(authorizationHeader)
			if match == nil {
				logger.Info(ctx, "msg", "Authorization error: Missing bearer token")
				httpErrorHandler(ctx, http.StatusForbidden, errors.New(errorhandler.MsgErrMissingParam+"."+errorhandler.BearerToken), w)
				return
			}
			// match[0] is the global matched group. match[1] is the first captured group
			var accessToken = match[1]

			jot, err := ParseAndValidateOIDCToken(ctx, accessToken, keycloakClient, audienceRequired, logger)
			if err != nil {
				var invalidTokenErr = errors.New(errorhandler.MsgErrInvalidParam + "." + errorhandler.Token)
				switch errorsPkg.Cause(err).(type) {
				case security.ForbiddenError:
					httpErrorHandler(ctx, http.StatusForbidden, invalidTokenErr, w)
				case errorhandler.UnauthorizedError:
					httpErrorHandler(ctx, http.StatusUnauthorized, invalidTokenErr, w)
				default:
					// Never return without writing a response: net/http would then
					// answer 200 OK to a request whose token was not validated.
					httpErrorHandler(ctx, http.StatusForbidden, invalidTokenErr, w)
				}
				return
			}

			var issuerDomain, realm = splitIssuer(jot.GetIssuer())
			ctx = context.WithValue(ctx, cs.CtContextAccessToken, accessToken)
			ctx = context.WithValue(ctx, cs.CtContextRealm, realm)
			ctx = context.WithValue(ctx, cs.CtContextUserID, jot.GetSubject())
			ctx = context.WithValue(ctx, cs.CtContextUsername, jot.GetUsername())
			ctx = context.WithValue(ctx, cs.CtContextGroups, ExtractGroups(jot.GetGroups()))
			ctx = context.WithValue(ctx, cs.CtContextIssuerDomain, issuerDomain)
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}
}
// ParseAndValidateOIDCToken ensures the OIDC token given in parameter is valid.
// This method must be public as it is used externally by some projects.
//
// The token claims are first decoded without any signature check. The audience is
// checked against audienceRequired, then the validity of the token is delegated to
// keycloakClient.VerifyToken for the domain and realm found in its issuer.
// On success it returns the decoded claims. Errors:
//   - security.ForbiddenError when the token cannot be decoded, when its issuer has
//     no "/realms/" segment, or when audienceRequired is not among its audiences;
//   - errorhandler.UnauthorizedError when keycloakClient.VerifyToken fails.
func ParseAndValidateOIDCToken(ctx context.Context, accessToken string, keycloakClient KeycloakClient, audienceRequired string, logger log.Logger) (TokenAudience, error) {
	// No signature verification here: that is keycloakClient.VerifyToken's job below.
	token, _, err := jwt.NewParser().ParseUnverified(accessToken, jwt.MapClaims{})
	if err != nil {
		logger.Info(ctx, "msg", "Authorization error", "err", err)
		return nil, security.ForbiddenError{}
	}

	// Round-trip the generic claims through JSON to get a typed token.
	payload, err := json.Marshal(token.Claims)
	if err != nil {
		logger.Info(ctx, "msg", "Authorization error", "err", err)
		return nil, security.ForbiddenError{}
	}

	var jot TokenAudience
	if jot, err = unmarshalTokenAudience(payload); err != nil {
		logger.Info(ctx, "msg", "Authorization error", "err", err)
		return nil, security.ForbiddenError{}
	}

	if !jot.AssertMatchingAudience(audienceRequired) {
		logger.Info(ctx, "msg", "Authorization error: Incorrect audience", "audience", jot.GetAudience())
		return nil, security.ForbiddenError{}
	}

	var issuer = jot.GetIssuer()
	// The issuer must look like ".../realms/<realm>" (".../auth/realms/<realm>" also
	// contains "/realms/"). Reject anything else as a malformed token instead of
	// letting splitIssuer fail on it.
	if !strings.Contains(issuer, "/realms/") {
		logger.Info(ctx, "msg", "Authorization error: Invalid issuer", "issuer", issuer)
		return nil, security.ForbiddenError{}
	}

	var issuerDomain, realm = splitIssuer(issuer)
	if err = keycloakClient.VerifyToken(issuerDomain, realm, accessToken); err != nil {
		logger.Info(ctx, "msg", "Authorization error", "err", err)
		return nil, errorhandler.UnauthorizedError{}
	}

	// No error during the validation process: return the decoded token.
	return jot, nil
}
// AssertMatchingAudience reports whether requiredAudience is one of the audiences
// carried by the JWT. A nil or empty audience list never matches.
func AssertMatchingAudience(jwtAudiences []string, requiredAudience string) bool {
	for i := range jwtAudiences {
		if jwtAudiences[i] != requiredAudience {
			continue
		}
		return true
	}
	return false
}
// ExtractGroups returns the given Keycloak group paths, in order, with one leading
// '/' removed from each (e.g. "/admins" -> "admins").
// The result is never nil: an empty input yields an empty, non-nil slice.
func ExtractGroups(kcGroups []string) []string {
	groups := make([]string, len(kcGroups))
	for i, group := range kcGroups {
		groups[i] = strings.TrimPrefix(group, "/")
	}
	return groups
}
// TokenAudienceStringArray is a JWT token with the custom fields present in the
// OIDC tokens provided by Keycloak, for tokens whose audience ("aud") is a string array.
// According to the specification, the audience can be either a string or a string
// array. The libraries did not support this when this code was written (a fix was in
// progress), so it is worked around with two token types; see TokenAudienceString.
type TokenAudienceStringArray struct {
	// hdr is unexported, so encoding/json ignores it; it is not used in this file.
	hdr            *header
	Issuer         string   `json:"iss,omitempty"`
	Subject        string   `json:"sub,omitempty"`
	Audience       []string `json:"aud,omitempty"` // array form of the "aud" claim
	ExpirationTime int64    `json:"exp,omitempty"`
	NotBefore      int64    `json:"nbf,omitempty"`
	IssuedAt       int64    `json:"iat,omitempty"`
	ID             string   `json:"jti,omitempty"`
	Username       string   `json:"preferred_username,omitempty"`
	Groups         []string `json:"groups,omitempty"`
}
// TokenAudienceString is a JWT token whose audience ("aud") is a single string.
// It carries the same claims as TokenAudienceStringArray and is the fallback used
// when the audience is not a string array.
type TokenAudienceString struct {
	// hdr is unexported, so encoding/json ignores it; it is not used in this file.
	hdr            *header
	Issuer         string   `json:"iss,omitempty"`
	Subject        string   `json:"sub,omitempty"`
	Audience       string   `json:"aud,omitempty"` // single-string form of the "aud" claim
	ExpirationTime int64    `json:"exp,omitempty"`
	NotBefore      int64    `json:"nbf,omitempty"`
	IssuedAt       int64    `json:"iat,omitempty"`
	ID             string   `json:"jti,omitempty"`
	Username       string   `json:"preferred_username,omitempty"`
	Groups         []string `json:"groups,omitempty"`
}
// TokenAudience is the common view over the decoded token types
// (TokenAudienceStringArray and TokenAudienceString), independent of whether the
// "aud" claim was a string array or a single string.
type TokenAudience interface {
	// GetSubject returns the "sub" claim.
	GetSubject() string
	// GetUsername returns the "preferred_username" claim.
	GetUsername() string
	// GetIssuer returns the "iss" claim.
	GetIssuer() string
	// GetGroups returns the "groups" claim.
	GetGroups() []string
	// GetAudience returns the "aud" claim: a []string or a string depending on the
	// concrete token type.
	GetAudience() any
	// AssertMatchingAudience reports whether requiredValue is one of the token audiences.
	AssertMatchingAudience(requiredValue string) bool
}
// header holds JWT header fields (alg, kid, typ, cty).
// It is only referenced by the unexported hdr field of the token types.
type header struct {
	Algorithm   string `json:"alg,omitempty"`
	KeyID       string `json:"kid,omitempty"`
	Type        string `json:"typ,omitempty"`
	ContentType string `json:"cty,omitempty"`
}
// unmarshalTokenAudience decodes a JSON claims payload into a TokenAudience.
// The "aud" claim may be a string array or a single string, so the array form
// (TokenAudienceStringArray) is tried first and the string form
// (TokenAudienceString) second. If both fail, the error of the second attempt
// is returned.
func unmarshalTokenAudience(payload []byte) (TokenAudience, error) {
	arrayToken := &TokenAudienceStringArray{}
	if arrayErr := json.Unmarshal(payload, arrayToken); arrayErr == nil {
		return arrayToken, nil
	}

	stringToken := &TokenAudienceString{}
	stringErr := json.Unmarshal(payload, stringToken)
	if stringErr != nil {
		return nil, stringErr
	}
	return stringToken, nil
}
// GetSubject returns the "sub" claim of the token.
func (t *TokenAudienceStringArray) GetSubject() string {
	return t.Subject
}

// GetUsername returns the "preferred_username" claim of the token.
func (t *TokenAudienceStringArray) GetUsername() string {
	return t.Username
}

// GetIssuer returns the "iss" claim of the token.
func (t *TokenAudienceStringArray) GetIssuer() string {
	return t.Issuer
}

// GetGroups returns the "groups" claim of the token.
func (t *TokenAudienceStringArray) GetGroups() []string {
	return t.Groups
}

// GetAudience returns the "aud" claim of the token, as a []string.
func (t *TokenAudienceStringArray) GetAudience() any {
	return t.Audience
}

// AssertMatchingAudience reports whether requiredValue is one of the token audiences.
func (t *TokenAudienceStringArray) AssertMatchingAudience(requiredValue string) bool {
	return AssertMatchingAudience(t.Audience, requiredValue)
}
// GetSubject returns the "sub" claim of the token.
func (t *TokenAudienceString) GetSubject() string {
	return t.Subject
}

// GetUsername returns the "preferred_username" claim of the token.
func (t *TokenAudienceString) GetUsername() string {
	return t.Username
}

// GetIssuer returns the "iss" claim of the token.
func (t *TokenAudienceString) GetIssuer() string {
	return t.Issuer
}

// GetGroups returns the "groups" claim of the token.
func (t *TokenAudienceString) GetGroups() []string {
	return t.Groups
}

// GetAudience returns the "aud" claim of the token, as a string.
func (t *TokenAudienceString) GetAudience() any {
	return t.Audience
}

// AssertMatchingAudience reports whether requiredValue equals the single audience
// of the token.
func (t *TokenAudienceString) AssertMatchingAudience(requiredValue string) bool {
	return requiredValue == t.Audience
}