-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathsession.go
193 lines (160 loc) · 4.81 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
package doubleratchet
import (
"bytes"
"fmt"
)
// Session is the interface through which one party takes part in the
// Double Ratchet Algorithm.
type Session interface {
// RatchetEncrypt performs a symmetric-key ratchet step, then AEAD-encrypts the message with
// the resulting message key. It returns the Message (header plus ciphertext) to send.
RatchetEncrypt(plaintext, associatedData []byte) (Message, error)
// RatchetDecrypt is called to AEAD-decrypt messages. It returns the plaintext of m,
// or an error if the message cannot be decrypted.
RatchetDecrypt(m Message, associatedData []byte) ([]byte, error)
// DeleteMk removes a message key from the database.
// NOTE(review): the parameters are unnamed here; the sessionState implementation
// names them dh and n and forwards them to MkSkipped.DeleteMk.
DeleteMk(Key, uint32) error
}
// sessionState is the Session implementation of this file. It bundles the
// ratchet State with the identifier and storage used to persist it.
type sessionState struct {
// id is the identifier handed to storage.Save when the state is persisted.
id []byte
// State is embedded, so its fields and methods are used directly on the session.
State
// storage persists the state; store() does nothing when it is nil.
storage SessionStorage
}
// New builds a session from the shared key and the caller's own DH key pair
// and saves it through storage. If the initial state cannot be built, a nil
// session and the error are returned; otherwise the session is returned
// together with the result of the save.
func New(id []byte, sharedKey Key, keyPair DHPair, storage SessionStorage, opts ...option) (Session, error) {
	st, err := newState(sharedKey, opts...)
	if err != nil {
		return nil, err
	}
	st.DHs = keyPair

	sess := &sessionState{
		id:      id,
		State:   st,
		storage: storage,
	}
	err = sess.store()
	return sess, err
}
// NewWithRemoteKey builds a session from the shared key and the other party's
// public key. It generates a fresh local DH pair, derives the DH secret against
// remoteKey, steps the root chain with that secret to obtain the sending chain,
// and saves the session through storage. The session is returned together with
// the result of the save.
func NewWithRemoteKey(id []byte, sharedKey, remoteKey Key, storage SessionStorage, opts ...option) (Session, error) {
	st, err := newState(sharedKey, opts...)
	if err != nil {
		return nil, err
	}

	if st.DHs, err = st.Crypto.GenerateDH(); err != nil {
		return nil, fmt.Errorf("can't generate key pair: %s", err)
	}

	st.DHr = remoteKey
	dhOut, err := st.Crypto.DH(st.DHs, st.DHr)
	if err != nil {
		return nil, fmt.Errorf("can't generate dh secret: %s", err)
	}

	// Only the first result of the root-chain step (the new sending chain) is
	// kept; the second result is deliberately discarded.
	st.SendCh, _ = st.RootCh.step(dhOut)

	sess := &sessionState{
		id:      id,
		State:   st,
		storage: storage,
	}
	err = sess.store()
	return sess, err
}
// Load fetches the state stored under id from store, applies opts to it and
// wraps it in a session bound to the same store. When the store reports no
// state (nil state and nil error), Load returns a nil session and a nil error.
func Load(id []byte, store SessionStorage, opts ...option) (Session, error) {
	loaded, err := store.Load(id)
	switch {
	case err != nil:
		return nil, err
	case loaded == nil:
		return nil, nil
	}

	if err := loaded.applyOptions(opts); err != nil {
		return nil, err
	}

	return &sessionState{
		id:      id,
		State:   *loaded,
		storage: store,
	}, nil
}
// store saves the current state under the session id. It is a no-op that
// returns nil when the session has no storage configured.
func (s *sessionState) store() error {
	if s.storage == nil {
		return nil
	}
	return s.storage.Save(s.id, &s.State)
}
// RatchetEncrypt performs a symmetric-key ratchet step, then encrypts the message with
// the resulting message key.
//
// The message header is built from the current sending public key, the sending
// chain's message number and PN. The data authenticated alongside the plaintext
// is ad followed by the encoded header. ad itself is never modified. On
// success the updated state is saved and the header and ciphertext are
// returned as a Message; on any error an empty Message is returned.
func (s *sessionState) RatchetEncrypt(plaintext, ad []byte) (Message, error) {
	var (
		// The header must capture N before the chain is stepped.
		h = MessageHeader{
			DH: s.DHs.PublicKey(),
			N:  s.SendCh.N,
			PN: s.PN,
		}
		mk = s.SendCh.step()
	)

	// Cap ad at its own length so that append has to allocate a new array
	// instead of writing the header into spare capacity of the caller's slice.
	aad := append(ad[:len(ad):len(ad)], h.Encode()...)

	ct, err := s.Crypto.Encrypt(mk, plaintext, aad)
	if err != nil {
		return Message{}, err
	}

	// Store state
	if err := s.store(); err != nil {
		return Message{}, err
	}

	return Message{h, ct}, nil
}
// DeleteMk removes the message key identified by dh and message number n by
// delegating to the MkSkipped store of the session state.
func (s *sessionState) DeleteMk(dh Key, n uint32) error {
	msgNum := uint(n)
	return s.MkSkipped.DeleteMk(dh, msgNum)
}
// RatchetDecrypt is called to decrypt messages.
//
// If a message key for the header's (DH, N) pair is found in MkSkipped, it is
// used directly. Otherwise the ratchet is advanced on a copy of the state: a
// DH ratchet step is performed when the header carries a new ratchet key, the
// receiving chain is advanced up to N, and the message is decrypted with the
// resulting key. The copy is applied to the session only after decryption
// succeeded, and the state is then saved.
//
// The data authenticated alongside the ciphertext is ad followed by the
// encoded header. ad itself is never modified. It returns the plaintext, or
// nil and an error.
func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
	// Is the message one of the skipped?
	mk, ok, err := s.MkSkipped.Get(m.Header.DH, uint(m.Header.N))
	if err != nil {
		return nil, err
	}

	// Cap ad at its own length so that append has to allocate a new array
	// instead of writing the header into spare capacity of the caller's slice.
	aad := append(ad[:len(ad):len(ad)], m.Header.Encode()...)

	if ok {
		plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, aad)
		if err != nil {
			return nil, fmt.Errorf("can't decrypt skipped message: %s", err)
		}
		if err := s.store(); err != nil {
			return nil, err
		}
		return plaintext, nil
	}

	var (
		// All changes must be applied on a different state object, so that this
		// session is neither modified nor left in a dirty state if a step fails.
		sc = s.State

		skippedKeys1 []skippedKey
		skippedKeys2 []skippedKey
	)

	// Is there a new ratchet key?
	if !bytes.Equal(m.Header.DH, sc.DHr) {
		// Collect the remaining keys of the previous receiving chain (up to PN)
		// before switching to the new ratchet key.
		if skippedKeys1, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.PN)); err != nil {
			return nil, fmt.Errorf("can't skip previous chain message keys: %s", err)
		}
		if err = sc.dhRatchet(m.Header); err != nil {
			return nil, fmt.Errorf("can't perform ratchet step: %s", err)
		}
	}

	// After all, update the current chain.
	if skippedKeys2, err = sc.skipMessageKeys(sc.DHr, uint(m.Header.N)); err != nil {
		return nil, fmt.Errorf("can't skip current chain message keys: %s", err)
	}
	mk = sc.RecvCh.step()
	plaintext, err := s.Crypto.Decrypt(mk, m.Ciphertext, aad)
	if err != nil {
		return nil, fmt.Errorf("can't decrypt: %s", err)
	}

	// Append current key, waiting for confirmation
	skippedKeys := append(skippedKeys1, skippedKeys2...)
	skippedKeys = append(skippedKeys, skippedKey{
		key: sc.DHr,
		nr:  uint(m.Header.N),
		mk:  mk,
		seq: sc.KeysCount,
	})

	// Increment the number of keys
	sc.KeysCount++

	// Apply changes.
	if err := s.applyChanges(sc, s.id, skippedKeys); err != nil {
		return nil, err
	}

	// Store state
	if err := s.store(); err != nil {
		return nil, err
	}

	return plaintext, nil
}