-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathdatabase.go
540 lines (481 loc) · 15.2 KB
/
database.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package webrisk
import (
"bytes"
"compress/gzip"
"context"
"encoding/gob"
"errors"
"log"
"math/rand"
"os"
"sync"
"time"
pb "github.com/google/webrisk/internal/webrisk_proto"
)
// Timing constants used by the update backoff and staleness checks.
const (
// maxRetryDelay is the upper bound on the backoff delay returned by Update
// after a failed API call.
maxRetryDelay = 24 * time.Hour
// baseRetryDelay is the base unit of the exponential backoff computed by
// Update after a failed API call.
baseRetryDelay = 15 * time.Minute
// jitter is the maximum amount of time that we expect an API list update to
// actually take. We add this time to the update period time to give some
// leeway before declaring the database as stale.
jitter = 30 * time.Second
)
// database tracks the state of the threat lists published by the Webrisk API.
// Since the global blocklist is constantly changing, the contents of the
// database need to be periodically synced with the Webrisk servers in
// order to provide protection for the latest threats.
//
// The process for updating the database is as follows:
// - At startup, if a database file is provided, then load it. If loaded
// properly (not corrupted and not stale), then set tfu as the contents.
// Otherwise, pull a new threat list from the Web Risk API.
// - Periodically, synchronize the database with the Web Risk API.
// This uses the Version Token fields to update only parts of the threat list
// that have changed since the last sync.
// - Anytime tfu is updated, generate a new tfl.
//
// The process for querying the database is as follows:
// - Check if the requested full hash matches any partial hash in tfl.
// If a match is found, return a set of ThreatTypes with a partial match.
type database struct {
ml sync.RWMutex // Protects tfl, err, and last; setError and clearError also hold it while touching readyCh
// threatsForLookup maps ThreatTypes to sets of partial hashes.
// This data structure is in a format that is easily queried.
tfl threatsForLookup
err error // Last error encountered; nil while the database is healthy
last time.Time // Last time the threat lists were synced
config *Config // Configuration supplied to Init
// threatsForUpdate maps ThreatTypes to lists of partial hashes.
// This data structure is in a format that is easily updated by the API.
// It is also the form that is written to disk.
tfu threatsForUpdate
mu sync.Mutex // Protects tfu; held for the duration of Init and Update
readyCh chan struct{} // Used for waiting until not in an error state; replaced when an error is first set, closed by clearError
updateAPIErrors uint // Number of consecutive failed API calls; incremented by Update on failure and reset to 0 once all requests succeed
log *log.Logger // Logger supplied to Init
}
// threatsForUpdate maps each ThreatType to its partialHashes record. It is the
// representation that API diff responses are applied to (see update) and the
// table that is stored inside databaseFormat.
type threatsForUpdate map[ThreatType]partialHashes
// partialHashes is the per-ThreatType record held in threatsForUpdate: the
// list of hash prefixes, its checksum, and the API version token.
type partialHashes struct {
// Since the Hashes field is only needed when storing to disk and when
// updating, this field is cleared except for when it is in use.
// This is done to reduce memory usage as the contents of this can be
// regenerated from the tfl.
Hashes hashPrefixes
SHA256 []byte // The SHA256 over Hashes
State []byte // Arbitrary binary blob to synchronize state with API; sent back as the request's VersionToken
}
// threatsForLookup maps each ThreatType to a hashSet built from that list's
// hash prefixes. It is the representation queried by Lookup.
type threatsForLookup map[ThreatType]hashSet
// databaseFormat is a light struct used only for gob encoding and decoding.
// As written to disk, the format of the database file is basically the gzip
// compressed version of the gob encoding of databaseFormat.
type databaseFormat struct {
// Table holds the threat lists, including their hashes, as saved by Update.
Table threatsForUpdate
// Time is the timestamp of the update that produced Table; Init uses it
// to decide whether a loaded database is stale.
Time time.Time
}
// Init initializes the database from the file named by config.DBPath and
// records config and logger for later use.
//
// It reports true if the database was successfully loaded. If it reports
// false, use Status for more details on the failure. Loading fails when no
// path is configured, when the file cannot be read or decoded, when the stored
// timestamp is stale, or when the file lacks any of the threat lists named in
// config.ThreatLists.
func (db *database) Init(config *Config, logger *log.Logger) bool {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Start in a faulted state; every path below either replaces this error
	// with a more specific one or clears it on success.
	db.setError(errors.New("not initialized"))
	db.config = config
	db.log = logger

	if db.config.DBPath == "" {
		db.log.Printf("no database file specified")
		db.setError(errors.New("no database loaded"))
		return false
	}

	dbf, err := loadDatabase(db.config.DBPath)
	if err != nil {
		db.log.Printf("load failure: %v", err)
		db.setError(err)
		return false
	}

	// Validate that the database threat list stored on disk is not too stale.
	if db.isStale(dbf.Time) {
		db.log.Printf("database loaded is stale")
		// setStale expects db.ml to be held by the caller.
		db.ml.Lock()
		defer db.ml.Unlock()
		db.setStale()
		return false
	}

	// Validate that the database threat list stored on disk is at least a
	// superset of the specified configuration; extra lists are dropped.
	tfuNew := make(threatsForUpdate, len(db.config.ThreatLists))
	for _, td := range db.config.ThreatLists {
		row, ok := dbf.Table[td]
		if !ok {
			db.log.Printf("database configuration mismatch, missing %v", td)
			db.setError(errors.New("database configuration mismatch"))
			return false
		}
		tfuNew[td] = row
	}

	db.tfu = tfuNew
	db.generateThreatsForLookups(dbf.Time)
	return true
}
// Status reports the health of the database. The database is considered faulted
// if there was an error during update or if the last update has gone stale. If
// in a faulted state, the db may repair itself on the next Update.
//
// It returns nil when the database is healthy, otherwise the recorded error.
func (db *database) Status() error {
	// Fast path: snapshot the state under the read lock only.
	db.ml.RLock()
	err, last := db.err, db.last
	db.ml.RUnlock()

	if err != nil {
		return err
	}
	if !db.isStale(last) {
		return nil
	}

	// Marking the database stale mutates db.err and db.readyCh, so it must be
	// done under the write lock. Re-check the state after acquiring it, since
	// another goroutine may have changed it in the meantime.
	db.ml.Lock()
	defer db.ml.Unlock()
	if db.err == nil && db.isStale(db.last) {
		db.setStale()
	}
	return db.err
}
// UpdateLag reports how far past the configured update period the database
// currently is. It returns 0 while the time since the last update is still
// shorter than the update period.
func (db *database) UpdateLag() time.Duration {
	elapsed := db.SinceLastUpdate()
	period := db.config.UpdatePeriod
	if elapsed >= period {
		return elapsed - period
	}
	return 0
}
// SinceLastUpdate returns the time elapsed between the last recorded database
// sync and the current time as reported by the configured clock.
func (db *database) SinceLastUpdate() time.Duration {
	db.ml.RLock()
	defer db.ml.RUnlock()

	current := db.config.now()
	return current.Sub(db.last)
}
// Ready returns a channel that's closed when the database is ready for queries.
//
// The channel is read under db.ml because setError replaces db.readyCh while
// holding that lock; an unsynchronized read would race with it.
func (db *database) Ready() <-chan struct{} {
	db.ml.RLock()
	defer db.ml.RUnlock()
	return db.readyCh
}
// Update synchronizes the local threat lists with those maintained by the
// global Web Risk API servers. If the update is successful, Status should
// report a nil error.
//
// It returns the duration the caller should wait before the next update and
// whether this update succeeded. On an API failure the returned duration is an
// exponential backoff delay capped at maxRetryDelay; otherwise it is the
// configured update period with jitter, raised to the server's recommended
// next-diff time when that is later.
func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
	db.mu.Lock()
	defer db.mu.Unlock()

	// Construct one request per configured threat list, passing along the
	// version token from the previous sync (if any) so only diffs are sent.
	reqs := make([]*pb.ComputeThreatListDiffRequest, 0, len(db.config.ThreatLists))
	for _, td := range db.config.ThreatLists {
		var state []byte
		if row, ok := db.tfu[td]; ok {
			state = row.State
		}
		reqs = append(reqs, &pb.ComputeThreatListDiffRequest{
			ThreatType: pb.ThreatType(td),
			Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
				SupportedCompressions: db.config.compressionTypes,
				MaxDiffEntries:        db.config.MaxDiffEntries,
				MaxDatabaseEntries:    db.config.MaxDatabaseEntries,
			},
			VersionToken: state,
		})
	}

	resps := make([]*pb.ComputeThreatListDiffResponse, 0, len(reqs))
	// add jitter to wait time to avoid all servers lining up
	nextUpdateWait := db.config.UpdatePeriod + time.Duration(rand.Int31n(60)-30)*time.Second
	last := db.config.now()
	for _, req := range reqs {
		// Query the API for the threat list and update the database.
		resp, err := api.ListUpdate(ctx, req)
		if err != nil {
			db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
			db.setError(err)

			// backoff strategy: MIN((2**N * 15 minutes) * (RAND + 1), 24 hours)
			//
			// The exponent is clamped: without the clamp, a long run of
			// failures overflows first the float64-to-Duration conversion
			// and then the shift itself, yielding a zero or negative delay
			// that slips past the maxRetryDelay cap. 2**10 base delays is
			// already far beyond maxRetryDelay, so the clamp never changes
			// a result that was previously in range.
			const maxBackoffShift = 10
			shift := db.updateAPIErrors
			if shift > maxBackoffShift {
				shift = maxBackoffShift
			}
			n := 1 << shift
			delay := time.Duration(float64(n) * (rand.Float64() + 1) * float64(baseRetryDelay))
			if delay > maxRetryDelay {
				delay = maxRetryDelay
			}
			db.updateAPIErrors++
			return delay, false
		}
		resps = append(resps, resp)

		// Honor the server's recommended next-diff time when it is later
		// than our own schedule.
		if resp.RecommendedNextDiff != nil {
			serverMinWait := time.Until(resp.RecommendedNextDiff.AsTime())
			if serverMinWait > nextUpdateWait {
				nextUpdateWait = serverMinWait
				db.log.Printf("Server requested next update in %v", nextUpdateWait)
			}
		}
	}

	// If for some reason we missed a request or didn't get a response the
	// rest of the logic may fail.
	if len(reqs) != len(resps) {
		db.setError(errors.New("mismatch between requests sent and responses received"))
		return nextUpdateWait, false
	}
	db.updateAPIErrors = 0

	// Update the threat database with the response. The hash lists are first
	// regenerated from the lookup structure because they are cleared between
	// updates.
	db.generateThreatsForUpdate()
	for i, resp := range resps {
		// Assume a 1:1 correspondence between request and response
		if err := db.tfu.update(resp, ThreatType(reqs[i].ThreatType)); err != nil {
			db.setError(err)
			db.log.Printf("update failure: %v", err)
			db.tfu = nil
			return nextUpdateWait, false
		}
	}

	dbf := databaseFormat{make(threatsForUpdate, len(db.tfu)), last}
	for td, phs := range db.tfu {
		// Copy of partialHashes before generateThreatsForLookups clobbers it.
		dbf.Table[td] = phs
	}
	db.generateThreatsForLookups(last)

	// Regenerate the database and store it.
	if db.config.DBPath != "" {
		// Semantically, we ignore save errors, but we do log them.
		if err := saveDatabase(db.config.DBPath, dbf); err != nil {
			db.log.Printf("save failure: %v", err)
		}
	}
	return nextUpdateWait, true
}
// Lookup checks the full hash against every threat list in the lookup table.
// It returns the ThreatTypes whose list reported a match, together with the
// matching prefix of hash; when several lists match, the prefix comes from
// whichever matching list was visited last. It panics if hash is not a full
// hash.
func (db *database) Lookup(hash hashPrefix) (h hashPrefix, tds []ThreatType) {
	if !hash.IsFull() {
		panic("hash is not full")
	}

	db.ml.RLock()
	defer db.ml.RUnlock()

	for threat, set := range db.tfl {
		matched := set.Lookup(hash)
		if matched <= 0 {
			continue
		}
		h = hash[:matched]
		tds = append(tds, threat)
	}
	return h, tds
}
// setError drops both threat tables, resets the last-sync time, and records
// err as the current error. When the database was healthy beforehand, a fresh
// readyCh is created so waiters block until the error is cleared.
//
// This assumes that the db.mu lock is already held.
func (db *database) setError(err error) {
	db.tfu = nil

	db.ml.Lock()
	defer db.ml.Unlock()

	if db.err == nil {
		db.readyCh = make(chan struct{})
	}
	db.tfl = nil
	db.err = err
	db.last = time.Time{}
}
// isStale reports whether lastUpdate is too old to be trusted. An update is
// stale once its age exceeds twice the sum of the configured update period
// and jitter.
func (db *database) isStale(lastUpdate time.Time) bool {
	age := db.config.now().Sub(lastUpdate)
	limit := 2 * (db.config.UpdatePeriod + jitter)
	return age > limit
}
// setStale records errStale as the current error while leaving the threat
// tables untouched. When the database was healthy beforehand, a fresh readyCh
// is created.
//
// This assumes that the db.ml lock is already held.
func (db *database) setStale() {
	wasHealthy := db.err == nil
	db.err = errStale
	if wasHealthy {
		db.readyCh = make(chan struct{})
	}
}
// clearError resets the error state to healthy and, if an error was set,
// closes readyCh so that anything waiting on the Ready channel is released.
//
// This assumes that the db.mu lock is already held.
func (db *database) clearError() {
	db.ml.Lock()
	defer db.ml.Unlock()

	if db.err == nil {
		// Already healthy: nothing to close or reset.
		return
	}
	close(db.readyCh)
	db.err = nil
}
// generateThreatsForUpdate refills the Hashes field of every threatsForUpdate
// entry by exporting the corresponding set from threatsForLookup. The hash
// lists are not kept around between updates, to avoid needlessly occupying
// lots of memory, so they have to be rebuilt before a diff can be applied.
//
// This assumes that the db.mu lock is already held.
func (db *database) generateThreatsForUpdate() {
	if db.tfu == nil {
		db.tfu = make(threatsForUpdate)
	}

	db.ml.RLock()
	defer db.ml.RUnlock()

	for threat, set := range db.tfl {
		entry := db.tfu[threat]
		entry.Hashes = set.Export()
		db.tfu[threat] = entry
	}
}
// generateThreatsForLookups builds a new threatsForLookup table from the
// current threatsForUpdate table, installs it together with the last-sync
// timestamp, and clears any recorded error. Because the hashes now live in
// the lookup sets, the Hashes slices in threatsForUpdate are dropped so they
// can be garbage collected.
//
// This assumes that the db.mu lock is already held.
func (db *database) generateThreatsForLookups(last time.Time) {
	lookup := make(threatsForLookup)
	for threat, entry := range db.tfu {
		var set hashSet
		set.Import(entry.Hashes)
		lookup[threat] = set

		// Clear hashes to keep memory usage low
		entry.Hashes = nil
		db.tfu[threat] = entry
	}

	db.ml.Lock()
	recovering := db.err != nil
	db.tfl = lookup
	db.last = last
	db.ml.Unlock()

	if !recovering {
		return
	}
	db.clearError()
	db.log.Printf("database is now healthy")
}
// saveDatabase writes db to the file at path as a gzip-compressed gob
// encoding, creating or truncating the file. It returns the first error
// encountered while creating, encoding, or closing.
func saveDatabase(path string, db databaseFormat) (err error) {
	out, err := os.Create(path)
	if err != nil {
		return err
	}
	// Runs last: report the file close error unless something failed earlier.
	defer func() {
		closeErr := out.Close()
		if err == nil {
			err = closeErr
		}
	}()

	zw, err := gzip.NewWriterLevel(out, gzip.BestCompression)
	if err != nil {
		return err
	}
	// Runs before the file is closed so the compressed stream is flushed.
	defer func() {
		closeErr := zw.Close()
		if err == nil {
			err = closeErr
		}
	}()

	return gob.NewEncoder(zw).Encode(db)
}
// loadDatabase reads the gzip-compressed gob encoding of a databaseFormat from
// the file at path and verifies that each threat list's stored SHA256 matches
// the checksum of its hashes. It returns an error if the file cannot be
// opened, decompressed, decoded, or closed, or if any checksum disagrees.
func loadDatabase(path string) (db databaseFormat, err error) {
	in, err := os.Open(path)
	if err != nil {
		return db, err
	}
	defer func() {
		closeErr := in.Close()
		if err == nil {
			err = closeErr
		}
	}()

	zr, err := gzip.NewReader(in)
	if err != nil {
		return db, err
	}
	defer func() {
		closeErr := zr.Close()
		if err == nil {
			err = closeErr
		}
	}()

	if err = gob.NewDecoder(zr).Decode(&db); err != nil {
		return db, err
	}

	for _, entry := range db.Table {
		if bytes.Equal(entry.SHA256, entry.Hashes.SHA256()) {
			continue
		}
		return db, errors.New("webrisk: threat list SHA256 mismatch")
	}
	return db, nil
}
// update applies one API diff response to the entry for threat type td.
//
// The entry is modified in this order: removals (by index into the sorted
// hash list), then additions, then a re-sort, validation, and a SHA256
// comparison. The entry is written back to tfu only if every step succeeds;
// on error tfu[td] is not reassigned by this function.
//
// It returns an error when the removals or additions cannot be decoded, a
// removal index is out of range, the resulting hashes fail validation, or the
// checksum does not match. When resp.Removals is non-nil it additionally
// rejects a DIFF for a threat type with no existing entry, a RESET that
// carries removals, and any other response type.
func (tfu threatsForUpdate) update(resp *pb.ComputeThreatListDiffResponse, td ThreatType) error {
phs, ok := tfu[td]
removalQuantity := 0
// A RESET discards whatever was stored and starts from an empty entry.
if resp.ResponseType == pb.ComputeThreatListDiffResponse_RESET {
phs = partialHashes{}
}
// NOTE(review): the response-type validation below lives inside this branch,
// so it only runs when the response carries a Removals message — confirm
// that skipping it for removal-free responses is intended.
if resp.Removals != nil {
// Count the removal indices announced in either encoding.
if resp.Removals.RawIndices != nil {
removalQuantity += len(resp.Removals.RawIndices.Indices)
}
if resp.Removals.RiceIndices != nil {
// An EntryCount of 0 is still counted as one removal; presumably the
// rice encoding carries a first value outside the entry count — TODO confirm.
if resp.Removals.RiceIndices.EntryCount == 0 {
removalQuantity++
} else {
removalQuantity += int(resp.Removals.RiceIndices.EntryCount)
}
}
switch resp.ResponseType {
case pb.ComputeThreatListDiffResponse_DIFF:
if !ok {
return errors.New("webrisk: partial update received for non-existent key")
}
case pb.ComputeThreatListDiffResponse_RESET:
if removalQuantity > 0 {
return errors.New("webrisk: indices to be removed included in a full update")
}
default:
return errors.New("webrisk: unknown response type")
}
// Hashes must be sorted for removal logic to work properly.
phs.Hashes.Sort()
idxs, err := decodeIndices(resp.Removals)
if err != nil {
return err
}
// Mark each removed position with an empty string as a tombstone.
for _, i := range idxs {
if i < 0 || i >= int32(len(phs.Hashes)) {
return errors.New("webrisk: invalid removal index")
}
phs.Hashes[i] = ""
}
// If any removal was performed, compact the list of hashes.
if removalQuantity > 0 {
// Filter in place, reusing the existing backing array.
compactHashes := phs.Hashes[:0]
for _, h := range phs.Hashes {
if h != "" {
compactHashes = append(compactHashes, h)
}
}
phs.Hashes = compactHashes
}
}
if resp.Additions != nil {
hashes, err := decodeHashes(resp.Additions)
if err != nil {
return err
}
phs.Hashes = append(phs.Hashes, hashes...)
}
// Hashes must be sorted for SHA256 checksum to be correct.
phs.Hashes.Sort()
if err := phs.Hashes.Validate(); err != nil {
return err
}
// Adopt the checksum from the response when present; otherwise the
// previously stored checksum is the one compared below.
if cs := resp.GetChecksum(); cs != nil {
phs.SHA256 = cs.Sha256
}
if !bytes.Equal(phs.SHA256, phs.Hashes.SHA256()) {
return errors.New("webrisk: threat list SHA256 mismatch")
}
// Remember the version token so the next request asks only for newer diffs.
phs.State = resp.NewVersionToken
tfu[td] = phs
return nil
}