forked from cristalhq/jwt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalgo_es.go
132 lines (114 loc) · 2.64 KB
/
algo_es.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package jwt
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"math/big"
)
// NewSignerES builds a Signer that produces ECDSA signatures with key.
//
// It returns ErrNilKey when key is nil, and forwards any error from
// getParamsES when alg is not usable with the key's curve.
func NewSignerES(alg Algorithm, key *ecdsa.PrivateKey) (Signer, error) {
	if key == nil {
		return nil, ErrNilKey
	}

	// Signature length: two values, each as wide as the curve's bit size
	// rounded up to whole bytes.
	size := 2 * roundBytes(key.PublicKey.Params().BitSize)

	h, err := getParamsES(alg, size)
	if err != nil {
		return nil, err
	}

	signer := &esAlg{
		alg:        alg,
		hash:       h,
		privateKey: key,
		signSize:   size,
	}
	return signer, nil
}
// NewVerifierES builds a Verifier that checks ECDSA signatures against key.
//
// It returns ErrNilKey when key is nil, and forwards any error from
// getParamsES when alg is not usable with the key's curve.
func NewVerifierES(alg Algorithm, key *ecdsa.PublicKey) (Verifier, error) {
	if key == nil {
		return nil, ErrNilKey
	}

	// Signature length: two values, each as wide as the curve's bit size
	// rounded up to whole bytes.
	size := 2 * roundBytes(key.Params().BitSize)

	h, err := getParamsES(alg, size)
	if err != nil {
		return nil, err
	}

	verifier := &esAlg{
		alg:       alg,
		hash:      h,
		publicKey: key,
		signSize:  size,
	}
	return verifier, nil
}
// getParamsES maps an ES* algorithm to its hash function and validates size.
//
// It returns ErrUnsupportedAlg when alg is not ES256, ES384 or ES512, and
// ErrInvalidKey when alg.keySize() does not equal size. Callers in this file
// pass the signature length derived from the key's curve as size.
func getParamsES(alg Algorithm, size int) (crypto.Hash, error) {
	var selected crypto.Hash

	switch alg {
	case ES256:
		selected = crypto.SHA256
	case ES384:
		selected = crypto.SHA384
	case ES512:
		selected = crypto.SHA512
	default:
		return 0, ErrUnsupportedAlg
	}

	if expected := alg.keySize(); expected == size {
		return selected, nil
	}
	return 0, ErrInvalidKey
}
// esAlg is the ECDSA implementation returned by NewSignerES and NewVerifierES.
// Each constructor sets only the key it was given; the other key field stays nil.
type esAlg struct {
// alg is reported by Algorithm and compared with the token header in VerifyToken.
alg Algorithm
// hash is the hash function passed to hashPayload in Sign and Verify.
hash crypto.Hash
// publicKey is used by Verify; set by NewVerifierES only.
publicKey *ecdsa.PublicKey
// privateKey is used by Sign; set by NewSignerES only.
privateKey *ecdsa.PrivateKey
// signSize is the signature length in bytes, computed by the constructors
// as twice the curve bit size rounded up to whole bytes.
signSize int
}
// Algorithm returns the algorithm stored in es.alg.
func (es *esAlg) Algorithm() Algorithm {
return es.alg
}
// SignSize returns the signature length in bytes stored in es.signSize.
func (es *esAlg) SignSize() int {
return es.signSize
}
// Sign signs the digest that hashPayload produces for payload and returns the
// signature as a fixed-width byte slice of length SignSize().
//
// The result is the concatenation of the two ECDSA values r and s, each
// big-endian and left-padded with zero bytes to SignSize()/2 bytes.
//
// It returns the error from hashPayload or from ecdsa.Sign unchanged.
func (es *esAlg) Sign(payload []byte) ([]byte, error) {
	digest, err := hashPayload(es.hash, payload)
	if err != nil {
		return nil, err
	}

	// Check the error returned by ecdsa.Sign itself: on failure r and s are
	// nil, so continuing would dereference nil big.Int values below.
	r, s, err := ecdsa.Sign(rand.Reader, es.privateKey, digest)
	if err != nil {
		return nil, err
	}

	size := es.SignSize()
	pivot := size / 2

	rBytes, sBytes := r.Bytes(), s.Bytes()

	// big.Int.Bytes drops leading zeros, so right-align each value inside
	// its half of the zero-initialised buffer.
	signature := make([]byte, size)
	copy(signature[pivot-len(rBytes):], rBytes)
	copy(signature[pivot*2-len(sBytes):], sBytes)

	return signature, nil
}
// VerifyToken verifies the signature of token.
//
// It returns ErrAlgorithmMismatch when constTimeAlgEqual reports that the
// algorithm in the token header differs from es.alg; otherwise it returns the
// result of Verify on the token's payload and signature.
func (es *esAlg) VerifyToken(token *Token) error {
	if !constTimeAlgEqual(token.Header().Algorithm, es.alg) {
		return ErrAlgorithmMismatch
	}
	return es.Verify(token.Payload(), token.Signature())
}
// Verify checks signature against the digest that hashPayload produces for
// payload, using es.publicKey.
//
// signature must be exactly SignSize() bytes: the first half is read as the
// big-endian value r and the second half as s. It returns ErrInvalidSignature
// on a length mismatch or a failed ECDSA check, and forwards any error from
// hashPayload.
func (es *esAlg) Verify(payload, signature []byte) error {
	size := es.SignSize()
	if len(signature) != size {
		return ErrInvalidSignature
	}

	digest, err := hashPayload(es.hash, payload)
	if err != nil {
		return err
	}

	half := size / 2
	r := new(big.Int).SetBytes(signature[:half])
	s := new(big.Int).SetBytes(signature[half:])

	if ecdsa.Verify(es.publicKey, digest, r, s) {
		return nil
	}
	return ErrInvalidSignature
}
// roundBytes converts a size in bits to a size in bytes, adding one byte
// when the bit count leaves a positive remainder after division by 8.
func roundBytes(n int) int {
	whole, rem := n/8, n%8
	if rem <= 0 {
		return whole
	}
	return whole + 1
}