Skip to content

Commit

Permalink
internal/vuln: cache DB
Browse files Browse the repository at this point in the history
Put a cache in front of the vuln DB. The cache uses the modified time
of the DB to determine when it is out of date.

The cache has no maximum size, because we don't expect the vuln DB to
be too large. Currently the unzipped JSON files total under 3 MB, as
determined by listing the bucket with gsutil ls -l. Since most of it
is strings, that should be pretty close to the memory occupied by the
unmarshaled data.

Also, make other minor changes to the package that don't affect its
behavior, mostly improved doc.

Change-Id: I374fa043805fe0fa939afac80ac0ee21e71e9167
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/636515
Reviewed-by: Robert Findley <[email protected]>
kokoro-CI: kokoro <[email protected]>
Reviewed-by: Tatiana Bradley <[email protected]>
LUCI-TryBot-Result: Go LUCI <[email protected]>
  • Loading branch information
jba committed Dec 16, 2024
1 parent d5afb0b commit 0a279fb
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 97 deletions.
162 changes: 76 additions & 86 deletions internal/vuln/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
package vuln

import (
"bytes"
"context"
"encoding/json"
"fmt"
"path"
"sort"
"strings"
"sync"
"time"

"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/log"
Expand All @@ -21,18 +23,21 @@ import (

// Client reads Go vulnerability databases.
type Client struct {
src source
src source
mu sync.Mutex
cache map[string]any
modified time.Time // the modified time of the DB
modifiedFetched time.Time // when we last read the modified time from the source
}

// NewClient returns a client that can read from the vulnerability
// database in src (a URL representing either a http or file source).
// database in src, a URL representing either an http or file source.
func NewClient(src string) (*Client, error) {
s, err := NewSource(src)
if err != nil {
return nil, err
}

return &Client{src: s}, nil
return newClient(s), nil
}

// NewInMemoryClient creates an in-memory vulnerability client for use
Expand All @@ -42,9 +47,14 @@ func NewInMemoryClient(entries []*osv.Entry) (*Client, error) {
if err != nil {
return nil, err
}
return &Client{src: inMemory}, nil
return newClient(inMemory), nil
}

func newClient(src source) *Client {
return &Client{src: src, cache: map[string]any{}}
}

// A PackageRequest provides arguments to [Client.ByPackage].
type PackageRequest struct {
// Module is the module path to filter on.
// ByPackage will only return entries that affect this module.
Expand Down Expand Up @@ -95,40 +105,22 @@ func (c *Client) ByPackage(ctx context.Context, req *PackageRequest) (_ []*osv.E
})
}

// modulesFilter returns all the modules in the DB for which filter returns true.
func (c *Client) modulesFilter(ctx context.Context, filter func(*ModuleMeta) bool, n int) ([]*ModuleMeta, error) {
if n == 0 {
return nil, nil
}

b, err := c.modules(ctx)
all, _, err := get[[]*ModuleMeta](ctx, c, modulesEndpoint)
if err != nil {
return nil, err
}

dec, err := newStreamDecoder(b)
if err != nil {
return nil, err
}

ms := make([]*ModuleMeta, 0)
for dec.More() {
var m ModuleMeta
err := dec.Decode(&m)
if err != nil {
return nil, err
}
if filter(&m) {
ms = append(ms, &m)
if len(ms) == n {
return ms, nil
}
var ms []*ModuleMeta
for _, m := range all {
if filter(m) {
ms = append(ms, m)
}
}

if len(ms) == 0 {
return nil, nil
}

return ms, nil
}

Expand Down Expand Up @@ -163,18 +155,12 @@ func isAffected(e *osv.Entry, req *PackageRequest) bool {
func (c *Client) ByID(ctx context.Context, id string) (_ *osv.Entry, err error) {
derrors.Wrap(&err, "ByID(%s)", id)

b, err := c.entry(ctx, id)
entry, _, err := get[osv.Entry](ctx, c, path.Join(idDir, id))
if err != nil {
// entry only fails if the entry is not found, so do not return
// the error.
return nil, nil
}

var entry osv.Entry
if err := json.Unmarshal(b, &entry); err != nil {
return nil, err
}

return &entry, nil
}

Expand All @@ -183,29 +169,17 @@ func (c *Client) ByID(ctx context.Context, id string) (_ *osv.Entry, err error)
func (c *Client) ByAlias(ctx context.Context, alias string) (_ string, err error) {
derrors.Wrap(&err, "ByAlias(%s)", alias)

b, err := c.vulns(ctx)
vs, err := c.vulns(ctx)
if err != nil {
return "", err
}

dec, err := newStreamDecoder(b)
if err != nil {
return "", err
}

for dec.More() {
var v VulnMeta
err := dec.Decode(&v)
if err != nil {
return "", err
}
for _, v := range vs {
for _, vAlias := range v.Aliases {
if alias == vAlias {
return v.ID, nil
}
}
}

return "", derrors.NotFound
}

Expand Down Expand Up @@ -379,51 +353,67 @@ func (c *Client) byIDs(ctx context.Context, ids []string) (_ []*osv.Entry, err e

// IDs returns a list of the IDs of all the entries in the database.
func (c *Client) IDs(ctx context.Context) (_ []string, err error) {
b, err := c.vulns(ctx)
vs, err := c.vulns(ctx)
if err != nil {
return nil, err
}

dec, err := newStreamDecoder(b)
if err != nil {
return nil, err
}

var ids []string
for dec.More() {
var v VulnMeta
err := dec.Decode(&v)
if err != nil {
return nil, err
}
for _, v := range vs {
ids = append(ids, v.ID)
}

return ids, nil
}

// newStreamDecoder returns a decoder that can be used
// to read an array of JSON objects.
func newStreamDecoder(b []byte) (*json.Decoder, error) {
dec := json.NewDecoder(bytes.NewBuffer(b))
func (c *Client) vulns(ctx context.Context) ([]VulnMeta, error) {
vms, _, err := get[[]VulnMeta](ctx, c, vulnsEndpoint)
return vms, err
}

// skip open bracket
_, err := dec.Token()
// After this time, consider our value of modified to be stale.
// var for testing.
var modifiedStaleDur = 5 * time.Minute

// get returns the contents of endpoint as a T, checking the cache first.
// It also reports whether it found the value in the cache.
func get[T any](ctx context.Context, c *Client, endpoint string) (t T, cached bool, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if time.Since(c.modifiedFetched) > modifiedStaleDur {
// c.modified is stale; reread the DB's modified time.
data, err := c.src.get(ctx, dbEndpoint)
if err != nil {
return t, false, err
}
var m DBMeta
if err := json.Unmarshal(data, &m); err != nil {
return t, false, fmt.Errorf("unmarshaling DBMeta: %w", err)
}
c.modifiedFetched = time.Now()
// If the DB has been modified since the last time we checked,
// clear the cache and note the new modified time.
// We only compare modified times with each other, as per
// https://go.dev/doc/security/vuln/database#api:
// "the modified time should not be compared to wall clock time".
if !m.Modified.Equal(c.modified) {
clear(c.cache)
c.modified = m.Modified
}
}
if mms, ok := c.cache[endpoint]; ok {
return mms.(T), true, nil
}
c.mu.Unlock()
data, err := c.src.get(ctx, endpoint)
c.mu.Lock()
// NOTE: Errors aren't cached, so an endpoint that repeatedly fails will be expensive.
// On the other hand, we won't turn transient errors into near-permanent ones.
// TODO: cache 4xx errors, since we get a lot of spammy traffic.
if err != nil {
return nil, err
return t, false, err
}

return dec, nil
}

func (c *Client) modules(ctx context.Context) ([]byte, error) {
return c.src.get(ctx, modulesEndpoint)
}

func (c *Client) vulns(ctx context.Context) ([]byte, error) {
return c.src.get(ctx, vulnsEndpoint)
}

func (c *Client) entry(ctx context.Context, id string) ([]byte, error) {
return c.src.get(ctx, path.Join(idDir, id))
if err := json.Unmarshal(data, &t); err != nil {
return t, false, err
}
c.cache[endpoint] = t
return t, false, nil
}
58 changes: 57 additions & 1 deletion internal/vuln/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"path/filepath"
"reflect"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -574,7 +575,7 @@ func newTestClientFromTxtar(txtarFile string) (*Client, error) {
data[f.Name] = fdata
}

return &Client{&inMemorySource{data: data}}, nil
return newClient(&inMemorySource{data: data}), nil
}

func removeWhitespace(data []byte) ([]byte, error) {
Expand All @@ -584,3 +585,58 @@ func removeWhitespace(data []byte) ([]byte, error) {
}
return b.Bytes(), nil
}

func TestCache(t *testing.T) {
// Cannot be run in parallel: changes modifiedStaleDur.
// TODO(jba): use the synctest package when pkgsite is on Go 1.24 or higher.
ctx := context.Background()
endpoint := "test/endpoint"
want := "some data"
msrc := &inMemorySource{
data: map[string][]byte{
dbEndpoint: []byte(`{"modified":"2024-07-10T17:05:50Z"}`),
"test/endpoint": []byte(strconv.Quote(want)),
},
}
c := newClient(msrc)

check := func(wantCached bool) {
t.Helper()
got, gotCached, err := get[string](ctx, c, endpoint)
if err != nil {
t.Fatal(err)
}
if got != want || gotCached != wantCached {
t.Fatalf("got (%q, %t), want (%q, %t)", got, gotCached, want, wantCached)
}
}

// The first get is not cached.
check(false)
// The next one is.
check(true)

const stale = 100 * time.Millisecond
defer func(d time.Duration) { modifiedStaleDur = d }(modifiedStaleDur)
modifiedStaleDur = stale

// The modified time is refetched when stale.
mf := c.modifiedFetched
time.Sleep(2 * stale)
// The value is still cached, because the DB's modified time hasn't changed.
check(true)
if !c.modifiedFetched.After(mf) {
t.Fatal("modifiedFetched did not advance")
}

// The DB is modified.
msrc.data[dbEndpoint] = []byte(`{"modified":"2024-08-10T17:05:50Z"}`)
// The cache doesn't notice yet.
check(true)
// The cached modified time becomes stale.
time.Sleep(2 * stale)
// The cache is cleared.
check(false)
// The cache continues to work.
check(true)
}
8 changes: 4 additions & 4 deletions internal/vuln/regexp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,21 @@ var (
ghsaID = regexp.MustCompile(ci + ghsaRE)
)

// Canonical returns the canonical form of the given Go ID string
// CanonicalGoID returns the canonical form of the given Go ID string
// by correcting the case.
//
// If no canonical form can be found, returns false.
// If no canonical form can be found, it returns false.
func CanonicalGoID(id string) (_ string, ok bool) {
if goID.MatchString(id) {
return strings.ToUpper(id), true
}
return "", false
}

// Canonical returns the canonical form of the given alias ID string
// CanonicalAlias returns the canonical form of the given alias ID string
// (a CVE or GHSA id) by correcting the case.
//
// If no canonical form can be found, returns false.
// If no canonical form can be found, it returns false.
func CanonicalAlias(id string) (_ string, ok bool) {
if cveID.MatchString(id) {
return strings.ToUpper(id), true
Expand Down
5 changes: 4 additions & 1 deletion internal/vuln/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ package vuln

import "time"

var (
// This file describes the structure of the Go vulnerability database.
// See https://go.dev/doc/security/vuln/database.

const (
idDir = "ID"
dbEndpoint = "index/db"
modulesEndpoint = "index/modules"
Expand Down
Loading

0 comments on commit 0a279fb

Please sign in to comment.