Skip to content

Commit 321782a

Browse files
committed
Sign/Verify does take the decoded form now
1 parent 352f411 commit 321782a

18 files changed

+153
-115
lines changed

ecdsa.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,19 @@ func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interf
8989

9090
// Sign implements token signing for the SigningMethod.
9191
// For this signing method, key must be an ecdsa.PrivateKey struct
92-
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
92+
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte, error) {
9393
// Get the key
9494
var ecdsaKey *ecdsa.PrivateKey
9595
switch k := key.(type) {
9696
case *ecdsa.PrivateKey:
9797
ecdsaKey = k
9898
default:
99-
return "", ErrInvalidKeyType
99+
return nil, ErrInvalidKeyType
100100
}
101101

102102
// Create the hasher
103103
if !m.Hash.Available() {
104-
return "", ErrHashUnavailable
104+
return nil, ErrHashUnavailable
105105
}
106106

107107
hasher := m.Hash.New()
@@ -112,7 +112,7 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
112112
curveBits := ecdsaKey.Curve.Params().BitSize
113113

114114
if m.CurveBits != curveBits {
115-
return "", ErrInvalidKey
115+
return nil, ErrInvalidKey
116116
}
117117

118118
keyBytes := curveBits / 8
@@ -127,8 +127,8 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
127127
r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output.
128128
s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output.
129129

130-
return EncodeSegment(out), nil
130+
return out, nil
131131
} else {
132-
return "", err
132+
return nil, err
133133
}
134134
}

ecdsa_test.go

+12-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package jwt_test
33
import (
44
"crypto/ecdsa"
55
"os"
6+
"reflect"
67
"strings"
78
"testing"
89

@@ -90,15 +91,16 @@ func TestECDSASign(t *testing.T) {
9091
toSign := strings.Join(parts[0:2], ".")
9192
method := jwt.GetSigningMethod(data.alg)
9293
sig, err := method.Sign(toSign, ecdsaKey)
93-
9494
if err != nil {
9595
t.Errorf("[%v] Error signing token: %v", data.name, err)
9696
}
97-
if sig == parts[2] {
98-
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
97+
98+
ssig := encodeSegment(sig)
99+
if ssig == parts[2] {
100+
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
99101
}
100102

101-
err = method.Verify(toSign, decodeSegment(t, sig), ecdsaKey.Public())
103+
err = method.Verify(toSign, sig, ecdsaKey.Public())
102104
if err != nil {
103105
t.Errorf("[%v] Sign produced an invalid signature: %v", data.name, err)
104106
}
@@ -155,15 +157,15 @@ func BenchmarkECDSASigning(b *testing.B) {
155157
if err != nil {
156158
b.Fatalf("[%v] Error signing token: %v", data.name, err)
157159
}
158-
if sig == parts[2] {
160+
if reflect.DeepEqual(sig, decodeSegment(b, parts[2])) {
159161
b.Fatalf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
160162
}
161163
}
162164
})
163165
}
164166
}
165167

166-
func decodeSegment(t *testing.T, signature string) (sig []byte) {
168+
func decodeSegment(t interface{ Fatalf(string, ...any) }, signature string) (sig []byte) {
167169
var err error
168170
sig, err = jwt.NewParser().DecodeSegment(signature)
169171
if err != nil {
@@ -172,3 +174,7 @@ func decodeSegment(t *testing.T, signature string) (sig []byte) {
172174

173175
return
174176
}
177+
178+
func encodeSegment(sig []byte) string {
179+
return (&jwt.Token{}).EncodeSegment(sig)
180+
}

ed25519.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,25 @@ func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key inte
5656

5757
// Sign implements token signing for the SigningMethod.
5858
// For this signing method, key must be an ed25519.PrivateKey
59-
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (string, error) {
59+
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]byte, error) {
6060
var ed25519Key crypto.Signer
6161
var ok bool
6262

6363
if ed25519Key, ok = key.(crypto.Signer); !ok {
64-
return "", ErrInvalidKeyType
64+
return nil, ErrInvalidKeyType
6565
}
6666

6767
if _, ok := ed25519Key.Public().(ed25519.PublicKey); !ok {
68-
return "", ErrInvalidKey
68+
return nil, ErrInvalidKey
6969
}
7070

71-
// Sign the string and return the encoded result
72-
// ed25519 performs a two-pass hash as part of its algorithm. Therefore, we need to pass a non-prehashed message into the Sign function, as indicated by crypto.Hash(0)
71+
// Sign the string and return the result. ed25519 performs a two-pass hash
72+
// as part of its algorithm. Therefore, we need to pass a non-prehashed
73+
// message into the Sign function, as indicated by crypto.Hash(0)
7374
sig, err := ed25519Key.Sign(rand.Reader, []byte(signingString), crypto.Hash(0))
7475
if err != nil {
75-
return "", err
76+
return nil, err
7677
}
77-
return EncodeSegment(sig), nil
78+
79+
return sig, nil
7880
}

ed25519_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ func TestEd25519Sign(t *testing.T) {
7777
if err != nil {
7878
t.Errorf("[%v] Error signing token: %v", data.name, err)
7979
}
80-
if sig == parts[2] && !data.valid {
81-
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
80+
81+
ssig := encodeSegment(sig)
82+
if ssig == parts[2] && !data.valid {
83+
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
8284
}
8385
}
8486
}

hmac.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interfa
7373

7474
// Sign implements token signing for the SigningMethod.
7575
// Key must be []byte
76-
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
76+
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) {
7777
if keyBytes, ok := key.([]byte); ok {
7878
if !m.Hash.Available() {
79-
return "", ErrHashUnavailable
79+
return nil, ErrHashUnavailable
8080
}
8181

8282
hasher := hmac.New(m.Hash.New, keyBytes)
8383
hasher.Write([]byte(signingString))
8484

85-
return EncodeSegment(hasher.Sum(nil)), nil
85+
return hasher.Sum(nil), nil
8686
}
8787

88-
return "", ErrInvalidKeyType
88+
return nil, ErrInvalidKeyType
8989
}

hmac_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package jwt_test
22

33
import (
44
"os"
5+
"reflect"
56
"strings"
67
"testing"
78

@@ -72,7 +73,7 @@ func TestHMACSign(t *testing.T) {
7273
if err != nil {
7374
t.Errorf("[%v] Error signing token: %v", data.name, err)
7475
}
75-
if sig != parts[2] {
76+
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
7677
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
7778
}
7879
}

none.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ func (m *signingMethodNone) Verify(signingString string, sig []byte, key interfa
4141
}
4242

4343
// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
44-
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
44+
func (m *signingMethodNone) Sign(signingString string, key interface{}) ([]byte, error) {
4545
if _, ok := key.(unsafeNoneMagicConstant); ok {
46-
return "", nil
46+
return []byte{}, nil
4747
}
48-
return "", NoneSignatureTypeDisallowedError
48+
49+
return nil, NoneSignatureTypeDisallowedError
4950
}

none_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package jwt_test
22

33
import (
4+
"reflect"
45
"strings"
56
"testing"
67

@@ -65,7 +66,7 @@ func TestNoneSign(t *testing.T) {
6566
if err != nil {
6667
t.Errorf("[%v] Error signing token: %v", data.name, err)
6768
}
68-
if sig != parts[2] {
69+
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
6970
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
7071
}
7172
}

parser.go

+31-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ type Parser struct {
1919
skipClaimsValidation bool
2020

2121
validator *validator
22+
23+
decodeStrict bool
24+
25+
decodePaddingAllowed bool
2226
}
2327

2428
// NewParser creates a new Parser with the specified options
@@ -169,22 +173,43 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
169173
return token, parts, nil
170174
}
171175

172-
// DecodeSegment decodes a JWT specific base64url encoding with padding stripped
173-
//
174-
// Deprecated: In a future release, we will demote this function to a
175-
// non-exported function, since it should only be used internally
176+
// DecodeSegment decodes a JWT specific base64url encoding. This function will
177+
// take into account whether the [Parser] is configured with additional options,
178+
// such as [WithStrictDecoding] or [WithPaddingAllowed].
176179
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
177180
encoding := base64.RawURLEncoding
178181

179-
if DecodePaddingAllowed {
182+
if p.decodePaddingAllowed {
180183
if l := len(seg) % 4; l > 0 {
181184
seg += strings.Repeat("=", 4-l)
182185
}
183186
encoding = base64.URLEncoding
184187
}
185188

186-
if DecodeStrict {
189+
if p.decodeStrict {
187190
encoding = encoding.Strict()
188191
}
189192
return encoding.DecodeString(seg)
190193
}
194+
195+
// Parse parses, validates, verifies the signature and returns the parsed token.
196+
// keyFunc will receive the parsed token and should return the cryptographic key
197+
// for verifying the signature. The caller is strongly encouraged to set the
198+
// WithValidMethods option to validate the 'alg' claim in the token matches the
199+
// expected algorithm. For more details about the importance of validating the
200+
// 'alg' claim, see
201+
// https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
202+
func Parse(tokenString string, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
203+
return NewParser(options...).Parse(tokenString, keyFunc)
204+
}
205+
206+
// ParseWithClaims is a shortcut for NewParser().ParseWithClaims().
207+
//
208+
// Note: If you provide a custom claim implementation that embeds one of the
209+
// standard claims (such as RegisteredClaims), make sure that a) you either
210+
// embed a non-pointer version of the claims or b) if you are using a pointer,
211+
// allocate the proper memory for it before passing in the overall claims,
212+
// otherwise you might run into a panic.
213+
func ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc, options ...ParserOption) (*Token, error) {
214+
return NewParser(options...).ParseWithClaims(tokenString, claims, keyFunc)
215+
}

parser_option.go

+26
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,29 @@ func WithSubject(sub string) ParserOption {
9999
p.validator.expectedSub = sub
100100
}
101101
}
102+
103+
// WithPaddingAllowed will enable the codec used for decoding JWTs to allow
104+
// padding. Note that the JWS RFC7515 states that the tokens will utilize a
105+
// Base64url encoding with no padding. Unfortunately, some implementations of
106+
// JWT are producing non-standard tokens, and thus require support for decoding.
107+
// Note that this is a global variable, and updating it will change the behavior
108+
// on a package level, and is also NOT go-routine safe. To use the
109+
// non-recommended decoding, set this boolean to `true` prior to using this
110+
// package.
111+
func WithPaddingAllowed() ParserOption {
112+
return func(p *Parser) {
113+
p.decodePaddingAllowed = true
114+
}
115+
}
116+
117+
// WithStrictDecoding will switch the codec used for decoding JWTs into strict
118+
// mode. In this mode, the decoder requires that trailing padding bits are zero,
119+
// as described in RFC 4648 section 3.5. Note that this is a global variable,
120+
// and updating it will change the behavior on a package level, and is also NOT
121+
// go-routine safe. To use strict decoding, set this boolean to `true` prior to
122+
// using this package.
123+
func WithStrictDecoding() ParserOption {
124+
return func(p *Parser) {
125+
p.decodeStrict = true
126+
}
127+
}

parser_test.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -641,9 +641,6 @@ var setPaddingTestData = []struct {
641641
func TestSetPadding(t *testing.T) {
642642
for _, data := range setPaddingTestData {
643643
t.Run(data.name, func(t *testing.T) {
644-
jwt.DecodePaddingAllowed = data.paddedDecode
645-
jwt.DecodeStrict = data.strictDecode
646-
647644
// If the token string is blank, use helper function to generate string
648645
if data.tokenString == "" {
649646
data.tokenString = signToken(data.claims, data.signingMethod)
@@ -652,7 +649,16 @@ func TestSetPadding(t *testing.T) {
652649
// Parse the token
653650
var token *jwt.Token
654651
var err error
655-
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
652+
var opts []jwt.ParserOption = []jwt.ParserOption{jwt.WithoutClaimsValidation()}
653+
654+
if data.paddedDecode {
655+
opts = append(opts, jwt.WithPaddingAllowed())
656+
}
657+
if data.strictDecode {
658+
opts = append(opts, jwt.WithStrictDecoding())
659+
}
660+
661+
parser := jwt.NewParser(opts...)
656662

657663
// Figure out correct claims type
658664
token, err = parser.ParseWithClaims(data.tokenString, jwt.MapClaims{}, data.keyfunc)
@@ -666,8 +672,6 @@ func TestSetPadding(t *testing.T) {
666672
}
667673

668674
})
669-
jwt.DecodePaddingAllowed = false
670-
jwt.DecodeStrict = false
671675
}
672676
}
673677

rsa.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,27 @@ func (m *SigningMethodRSA) Verify(signingString string, sig []byte, key interfac
6767

6868
// Sign implements token signing for the SigningMethod
6969
// For this signing method, must be an *rsa.PrivateKey structure.
70-
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) (string, error) {
70+
func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte, error) {
7171
var rsaKey *rsa.PrivateKey
7272
var ok bool
7373

7474
// Validate type of key
7575
if rsaKey, ok = key.(*rsa.PrivateKey); !ok {
76-
return "", ErrInvalidKey
76+
return nil, ErrInvalidKey
7777
}
7878

7979
// Create the hasher
8080
if !m.Hash.Available() {
81-
return "", ErrHashUnavailable
81+
return nil, ErrHashUnavailable
8282
}
8383

8484
hasher := m.Hash.New()
8585
hasher.Write([]byte(signingString))
8686

8787
// Sign the string and return the encoded bytes
8888
if sigBytes, err := rsa.SignPKCS1v15(rand.Reader, rsaKey, m.Hash, hasher.Sum(nil)); err == nil {
89-
return EncodeSegment(sigBytes), nil
89+
return sigBytes, nil
9090
} else {
91-
return "", err
91+
return nil, err
9292
}
9393
}

0 commit comments

Comments
 (0)