diff --git a/filters/apiusagemonitoring/filter.go b/filters/apiusagemonitoring/filter.go index aa4efcb94e..7b3dd6f2a2 100644 --- a/filters/apiusagemonitoring/filter.go +++ b/filters/apiusagemonitoring/filter.go @@ -1,8 +1,6 @@ package apiusagemonitoring import ( - "encoding/base64" - "encoding/json" "fmt" "net/http" "net/url" @@ -10,6 +8,7 @@ import ( "time" "github.com/zalando/skipper/filters" + "github.com/zalando/skipper/jwt" ) const ( @@ -211,30 +210,17 @@ func createAndCacheMetricsNames(path *pathInfo, method string, methodIndex int) // It returns `nil` if it was not possible to parse the JWT body. func parseJwtBody(req *http.Request) jwtTokenPayload { ahead := req.Header.Get(authorizationHeaderName) - if !strings.HasPrefix(ahead, authorizationHeaderPrefix) { + tv := strings.TrimPrefix(ahead, authorizationHeaderPrefix) + if tv == ahead { return nil } - // split the header into the 3 JWT parts - fields := strings.Split(ahead, ".") - if len(fields) != 3 { - return nil - } - - // base64-decode the JWT body part - bodyJSON, err := base64.RawURLEncoding.DecodeString(fields[1]) - if err != nil { - return nil - } - - // un-marshall the JWT body from JSON - var bodyObject map[string]interface{} - err = json.Unmarshal(bodyJSON, &bodyObject) + token, err := jwt.Parse(tv) if err != nil { return nil } - return bodyObject + return token.Claims } type jwtTokenPayload map[string]interface{} diff --git a/filters/log/log.go b/filters/log/log.go index aa271b9cc3..9ca3a00286 100644 --- a/filters/log/log.go +++ b/filters/log/log.go @@ -7,7 +7,6 @@ package log import ( "bytes" - "encoding/base64" "encoding/json" "io" "os" @@ -17,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/zalando/skipper/filters" + "github.com/zalando/skipper/jwt" ) const ( @@ -209,31 +209,21 @@ func (ual *unverifiedAuditLogSpec) CreateFilter(args []interface{}) (filters.Fil func (ual *unverifiedAuditLogFilter) Request(ctx filters.FilterContext) { req := ctx.Request() ahead := req.Header.Get(authHeaderName) - if !strings.HasPrefix(ahead, authHeaderPrefix) { + tv := strings.TrimPrefix(ahead, authHeaderPrefix) + if tv == ahead { return } - fields := strings.FieldsFunc(ahead, func(r rune) bool { - return r == []rune(".")[0] - }) - if len(fields) == 3 { - sDec, err := base64.RawURLEncoding.DecodeString(fields[1]) - if err != nil { - return - } - - var j map[string]interface{} - err = json.Unmarshal(sDec, &j) - if err != nil { - return - } + token, err := jwt.Parse(tv) + if err != nil { + return + } - for i := 0; i < len(ual.TokenKeys); i++ { - if k, ok := j[ual.TokenKeys[i]]; ok { - if v, ok2 := k.(string); ok2 { - req.Header.Add(UnverifiedAuditHeader, cleanSub(v)) - return - } + for i := 0; i < len(ual.TokenKeys); i++ { + if k, ok := token.Claims[ual.TokenKeys[i]]; ok { + if v, ok2 := k.(string); ok2 { + req.Header.Add(UnverifiedAuditHeader, cleanSub(v)) + return } } } diff --git a/jwt/token.go b/jwt/token.go new file mode 100644 index 0000000000..1c56b13c4b --- /dev/null +++ b/jwt/token.go @@ -0,0 +1,39 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "errors" + "strings" +) + +var ( + errInvalidToken = errors.New("invalid jwt token") +) + +type Token struct { + Claims map[string]interface{} +} + +func Parse(value string) (*Token, error) { + parts := strings.Split(value, ".") + if len(parts) != 3 { + return nil, errInvalidToken + } + + var token Token + err := unmarshalBase64JSON(parts[1], &token.Claims) + if err != nil { + return nil, errInvalidToken + } + + return &token, nil +} + +func unmarshalBase64JSON(s string, v interface{}) error { + d, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return err + } + return json.Unmarshal(d, v) +} diff --git a/jwt/token_test.go b/jwt/token_test.go new file mode 100644 index 0000000000..7669ff664e --- /dev/null +++ b/jwt/token_test.go @@ -0,0 +1,66 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "reflect" + "testing" +) + +func TestParse(t *testing.T) { + for _, tt := range []struct { + value string + ok bool + claims map[string]interface{} + }{ + { + value: "", + ok: false, + }, { + value: "x", + ok: false, + }, { + value: "x.y", + ok: false, + }, { + value: "x.y.z", + ok: false, + }, { + value: "..", + ok: false, + }, { + value: "x..z", + ok: false, + }, { + value: "x." + marshalBase64JSON(t, map[string]interface{}{"hello": "world"}) + ".z", + ok: true, + claims: map[string]interface{}{"hello": "world"}, + }, { + value: "." + marshalBase64JSON(t, map[string]interface{}{"no header": "no signature"}) + ".", + ok: true, + claims: map[string]interface{}{"no header": "no signature"}, + }, + } { + token, err := Parse(tt.value) + if tt.ok { + if err != nil { + t.Errorf("unexpected error for %s: %v", tt.value, err) + continue + } + } else { + continue + } + + if !reflect.DeepEqual(tt.claims, token.Claims) { + t.Errorf("claims mismatch, expected: %v, got %v", tt.claims, token.Claims) + } + } +} + +func marshalBase64JSON(t *testing.T, v interface{}) string { + d, err := json.Marshal(v) + if err != nil { + t.Fatalf("failed to marshal json: %v, %v", v, err) + } + return base64.RawURLEncoding.EncodeToString(d) +} diff --git a/predicates/auth/jwt.go b/predicates/auth/jwt.go index 32748c141d..a56132724b 100644 --- a/predicates/auth/jwt.go +++ b/predicates/auth/jwt.go @@ -17,12 +17,11 @@ Examples: package auth import ( - "encoding/base64" - "encoding/json" "net/http" "regexp" "strings" + "github.com/zalando/skipper/jwt" "github.com/zalando/skipper/predicates" "github.com/zalando/skipper/routing" ) @@ -153,33 +152,21 @@ func (m regexMatcher) Match(jwtValue string) bool { func (p *predicate) Match(r *http.Request) bool { ahead := r.Header.Get(authHeaderName) - if !strings.HasPrefix(ahead, authHeaderPrefix) { + tv := strings.TrimPrefix(ahead, authHeaderPrefix) + if tv == ahead { return false } - fields := strings.FieldsFunc(ahead, func(r rune) bool { - return r == []rune(".")[0] - }) - if len(fields) != 3 { - return false - } - - sDec, err := base64.RawURLEncoding.DecodeString(fields[1]) - if err != nil { - return false - } - - var payload map[string]interface{} - err = json.Unmarshal(sDec, &payload) + token, err := jwt.Parse(tv) if err != nil { return false } switch p.matchBehavior { case matchBehaviorAll: - return allMatch(p.kv, payload) + return allMatch(p.kv, token.Claims) case matchBehaviorAny: - return anyMatch(p.kv, payload) + return anyMatch(p.kv, token.Claims) default: return false }