diff --git a/example_test.go b/example_test.go index 651841de..98d305e9 100644 --- a/example_test.go +++ b/example_test.go @@ -98,6 +98,75 @@ func ExampleParseWithClaims_customClaimsType() { // Output: bar test } +type claimsV1 struct { + jwt.RegisteredClaims + ID string + Exp int64 +} + +type claimsV2 struct { + jwt.RegisteredClaims + ID string + UserID string +} + +func (c *claimsV1) Valid() error { return nil } +func (c *claimsV2) Valid() error { return nil } +func (c *claimsV1) Version() string { return "v1" } +func (c *claimsV2) Version() string { return "v2" } + +func (c *claimsV1) Decode(claims jwt.Claims) (map[string]interface{}, error) { + c, ok := claims.(*claimsV1) + if !ok { + return map[string]interface{}{}, errors.New("couldnt decode") + } + + return map[string]interface{}{ + "id": "bar", + "expiration": fmt.Sprint(c.Exp), + }, nil +} + +func (c *claimsV2) Decode(claims jwt.Claims) (map[string]interface{}, error) { + c, ok := claims.(*claimsV2) + if !ok { + return map[string]interface{}{}, errors.New("couldnt decode") + } + + return map[string]interface{}{ + "id": " test", + "user_id": fmt.Sprint(c.UserID), + }, nil +} + +func ExampleParseWitVersionedClaims_customClaimsType() { + tokenStringV1 := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInZlcnNpb24iOiJ2MSJ9.eyJJRCI6IjEyMyIsIkV4cCI6MTIzfQ.qbEStFoXm9UspByQtuSVa7vxP3z4-eGeLWf3mlONPgI" + tokenStringV2 := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInZlcnNpb24iOiJ2MiJ9.eyJJRCI6IjEyMyIsIlVzZXJJRCI6IjEyMyJ9.HM6A9inm-Lo8S-2JhS1W7zyqOUWNcMROOfIYnQP2Rcw" + + jwtVersions := map[string]jwt.VersionedClaims{ + "v1": &claimsV1{}, + "v2": &claimsV2{}, + } + + claimsDataV1, err := jwt.ParseWithVersionedClaims(tokenStringV1, jwtVersions, func(token *jwt.Token) (interface{}, error) { + return []byte("1"), nil + }) + + claimsDataV2, err := jwt.ParseWithVersionedClaims(tokenStringV2, jwtVersions, func(token *jwt.Token) (interface{}, error) { + return []byte("1"), nil + }) + + if err != nil { + log.Fatal(err) + } else if len(claimsDataV1) > 0 && len(claimsDataV1) > 0 { + fmt.Println(claimsDataV1["id"].(string) + claimsDataV2["id"].(string)) + } else { + log.Fatal("unknown claims type, cannot proceed") + } + + // Output: bar test +} + // Example creating a token using a custom claims type and validation options. The RegisteredClaims is embedded // in the custom type to allow for easy encoding, parsing and validation of standard claims. func ExampleParseWithClaims_validationOptions() { diff --git a/parser.go b/parser.go index ecf99af7..db2eb21d 100644 --- a/parser.go +++ b/parser.go @@ -236,3 +236,57 @@ func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) { return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc) } + +// ParseWithVersionedClaims parses and validates a JWT token with versioned claims. +// It extracts the "version" field from the token header to select the appropriate claims +// structure from the provided claimsMap. The keyFunc supplies the key for verification. +// +// Parameters: +// - tokenString: The JWT token string to parse. +// - claimsMap: A map associating version strings with corresponding VersionedClaims structs. +// - keyFunc: A function returning the key for verification based on the parsed token. +// - options: Optional ParserOption(s) for parsing configuration. +// +// Returns: +// - map[string]interface{}: The decoded claims. +// - error: An error if parsing or validation fails. +func ParseWithVersionedClaims(tokenString string, claimsMap map[string]VersionedClaims, keyFunc Keyfunc, options ...ParserOption) (map[string]interface{}, error) { + p := NewParser(options...) + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return nil, newError("token contains an invalid number of segments", ErrTokenMalformed) + } + + token := &Token{Raw: tokenString} + + headerBytes, err := p.DecodeSegment(parts[0]) + if err != nil { + return nil, newError("could not base64 decode header", ErrTokenMalformed, err) + } + + if err = json.Unmarshal(headerBytes, &token.Header); err != nil { + return nil, newError("could not JSON decode header", ErrTokenMalformed, err) + } + + versionValue, ok := token.Header["version"] + if !ok { + return nil, newError("version field missing in token header", ErrTokenMalformed) + } + + versionStr, ok := versionValue.(string) + if !ok { + return nil, newError("version field in token header is not a string", ErrTokenMalformed) + } + + claims, ok := claimsMap[versionStr] + if !ok { + return nil, newError(fmt.Sprintf("unsupported token version: %s", versionStr), ErrTokenMalformed) + } + + token, err = NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc) + if err != nil { + return nil, newError("could not parse token with claims", ErrTokenMalformed, err) + } + + return claims.Decode(token.Claims) +} diff --git a/token.go b/token.go index 9c7f4ab0..1aca3ece 100644 --- a/token.go +++ b/token.go @@ -55,6 +55,26 @@ func NewWithClaims(method SigningMethod, claims Claims, opts ...TokenOption) *To } } +type VersionedClaims interface { + Claims + Decode(claims Claims) (map[string]interface{}, error) + Version() string +} + +// NewWithClaims creates a new [Token] with the specified signing method and +// claims. Additional options can be specified, but are currently unused. +func NewWithVersion(method SigningMethod, claims VersionedClaims, opts ...TokenOption) *Token { + return &Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": method.Alg(), + "version": claims.Version(), + }, + Claims: claims, + Method: method, + } +} + // SignedString creates and returns a complete, signed JWT. The token is signed // using the SigningMethod specified in the token. Please refer to // https://golang-jwt.github.io/jwt/usage/signing_methods/#signing-methods-and-key-types