diff --git a/filters/auth/authclient.go b/filters/auth/authclient.go index ca5463eade..ccc71a1ef5 100644 --- a/filters/auth/authclient.go +++ b/filters/auth/authclient.go @@ -32,6 +32,7 @@ type authClient struct { type tokeninfoClient interface { getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error) + Close() } var _ tokeninfoClient = &authClient{} diff --git a/filters/auth/main_test.go b/filters/auth/main_test.go index 4da87fe2a4..3df902c5e3 100644 --- a/filters/auth/main_test.go +++ b/filters/auth/main_test.go @@ -18,11 +18,7 @@ func TestMain(m *testing.M) { func cleanupAuthClients() { for _, c := range tokeninfoAuthClient { - if ac, ok := c.(*authClient); ok { - ac.Close() - } else if cc, ok := c.(*tokeninfoCache); ok { - cc.client.(*authClient).Close() - } + c.Close() } for _, c := range issuerAuthClient { diff --git a/filters/auth/tokeninfo.go b/filters/auth/tokeninfo.go index 53323012cb..f12f24d868 100644 --- a/filters/auth/tokeninfo.go +++ b/filters/auth/tokeninfo.go @@ -12,6 +12,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/zalando/skipper/filters" "github.com/zalando/skipper/filters/annotate" + "github.com/zalando/skipper/metrics" ) const ( @@ -32,9 +33,10 @@ type TokeninfoOptions struct { Timeout time.Duration MaxIdleConns int Tracer opentracing.Tracer + Metrics metrics.Metrics // CacheSize configures the maximum number of cached tokens. - // The cache evicts least recently used items first. + // The cache periodically evicts random items when number of cached tokens exceeds CacheSize. // Zero value disables tokeninfo cache. CacheSize int @@ -100,7 +102,7 @@ func (o *TokeninfoOptions) newTokeninfoClient() (tokeninfoClient, error) { } if o.CacheSize > 0 { - c = newTokeninfoCache(c, o.CacheSize, o.CacheTTL) + c = newTokeninfoCache(c, o.Metrics, o.CacheSize, o.CacheTTL) } return c, nil } diff --git a/filters/auth/tokeninfocache.go b/filters/auth/tokeninfocache.go index c925ae5834..443c390832 100644 --- a/filters/auth/tokeninfocache.go +++ b/filters/auth/tokeninfocache.go @@ -1,32 +1,32 @@ package auth import ( - "container/list" + "maps" "sync" + "sync/atomic" "time" "github.com/zalando/skipper/filters" + "github.com/zalando/skipper/metrics" ) type ( tokeninfoCache struct { - client tokeninfoClient - size int - ttl time.Duration - now func() time.Time - - mu sync.Mutex - cache map[string]*entry - // least recently used token at the end - history *list.List + client tokeninfoClient + metrics metrics.Metrics + size int + ttl time.Duration + now func() time.Time + + cache sync.Map // map[string]*entry + count atomic.Int64 // estimated number of cached entries, see https://github.com/golang/go/issues/20680 + quit chan struct{} } entry struct { - cachedAt time.Time - expiresAt time.Time - info map[string]any - // reference in the history - href *list.Element + expiresAt time.Time + info map[string]any + infoExpiresAt time.Time } ) @@ -34,15 +34,22 @@ var _ tokeninfoClient = &tokeninfoCache{} const expiresInField = "expires_in" -func newTokeninfoCache(client tokeninfoClient, size int, ttl time.Duration) *tokeninfoCache { - return &tokeninfoCache{ +func newTokeninfoCache(client tokeninfoClient, metrics metrics.Metrics, size int, ttl time.Duration) *tokeninfoCache { + c := &tokeninfoCache{ client: client, + metrics: metrics, size: size, ttl: ttl, now: time.Now, - cache: make(map[string]*entry, size), - history: list.New(), + quit: make(chan struct{}), } + go c.evictLoop() + return c +} + +func (c *tokeninfoCache) Close() { + c.client.Close() + close(c.quit) } func (c *tokeninfoCache) getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error) { @@ -58,35 +65,21 @@ func (c *tokeninfoCache) getTokeninfo(token string, ctx filters.FilterContext) ( } func (c *tokeninfoCache) cached(token string) map[string]any { - now := c.now() - - c.mu.Lock() - - if e, ok := c.cache[token]; ok { + if v, ok := c.cache.Load(token); ok { + now := c.now() + e := v.(*entry) if now.Before(e.expiresAt) { - c.history.MoveToFront(e.href) - cachedInfo := e.info - c.mu.Unlock() - // It might be ok to return cached value // without adjusting "expires_in" to avoid copy // if caller never modifies the result and // when "expires_in" did not change (same second) // or for small TTL values - info := shallowCopyOf(cachedInfo) + info := maps.Clone(e.info) - elapsed := now.Sub(e.cachedAt).Truncate(time.Second).Seconds() - info[expiresInField] = info[expiresInField].(float64) - elapsed + info[expiresInField] = e.infoExpiresAt.Sub(now).Truncate(time.Second).Seconds() return info - } else { - // remove expired - delete(c.cache, token) - c.history.Remove(e.href) } } - - c.mu.Unlock() - return nil } @@ -95,38 +88,61 @@ func (c *tokeninfoCache) tryCache(token string, info map[string]any) { if expiresIn <= 0 { return } - if c.ttl > 0 && expiresIn > c.ttl { - expiresIn = c.ttl - } now := c.now() - expiresAt := now.Add(expiresIn) + e := &entry{ + info: info, + infoExpiresAt: now.Add(expiresIn), + } - c.mu.Lock() - defer c.mu.Unlock() + if c.ttl > 0 && expiresIn > c.ttl { + e.expiresAt = now.Add(c.ttl) + } else { + e.expiresAt = e.infoExpiresAt + } - if e, ok := c.cache[token]; ok { - // update - e.cachedAt = now - e.expiresAt = expiresAt - e.info = info - c.history.MoveToFront(e.href) - return + if _, loaded := c.cache.Swap(token, e); !loaded { + c.count.Add(1) } +} - // create - c.cache[token] = &entry{ - cachedAt: now, - expiresAt: expiresAt, - info: info, - href: c.history.PushFront(token), +func (c *tokeninfoCache) evictLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { + select { + case <-c.quit: + return + case <-ticker.C: + c.evict() + } } +} - // remove least used - if len(c.cache) > c.size { - leastUsed := c.history.Back() - delete(c.cache, leastUsed.Value.(string)) - c.history.Remove(leastUsed) +func (c *tokeninfoCache) evict() { + // Evict expired entries + c.cache.Range(func(key, value any) bool { + e := value.(*entry) + if c.now().After(e.expiresAt) { + if c.cache.CompareAndDelete(key, value) { + c.count.Add(-1) + } + } + return true + }) + + // Evict random entries until the cache size is within limits + if c.count.Load() > int64(c.size) { + c.cache.Range(func(key, value any) bool { + if c.cache.CompareAndDelete(key, value) { + c.count.Add(-1) + } + return c.count.Load() > int64(c.size) + }) + } + + if c.metrics != nil { + c.metrics.UpdateGauge("tokeninfocache.count", float64(c.count.Load())) } } @@ -141,11 +157,3 @@ func expiresIn(info map[string]any) time.Duration { } return 0 } - -func shallowCopyOf(info map[string]any) map[string]any { - m := make(map[string]any, len(info)) - for k, v := range info { - m[k] = v - } - return m -} diff --git a/filters/auth/tokeninfocache_test.go b/filters/auth/tokeninfocache_test.go index a5cfbf4a0e..86162207ac 100644 --- a/filters/auth/tokeninfocache_test.go +++ b/filters/auth/tokeninfocache_test.go @@ -12,6 +12,7 @@ import ( "github.com/zalando/skipper/filters" "github.com/zalando/skipper/filters/filtertest" + "github.com/zalando/skipper/metrics/metricstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,6 +23,7 @@ type tokeninfoClientFunc func(string, filters.FilterContext) (map[string]any, er func (f tokeninfoClientFunc) getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error) { return f(token, ctx) } +func (f tokeninfoClientFunc) Close() {} type testTokeninfoToken string @@ -58,6 +60,7 @@ func (c *testClock) now() time.Time { func TestTokeninfoCache(t *testing.T) { const ( TokenTTLSeconds = 600 + CacheSize = 1 CacheTTL = 300 * time.Second // less than TokenTTLSeconds ) @@ -79,15 +82,19 @@ func TestTokeninfoCache(t *testing.T) { })) defer authServer.Close() + m := &metricstest.MockMetrics{} + defer m.Close() + o := TokeninfoOptions{ URL: authServer.URL + "/oauth2/tokeninfo", - CacheSize: 1, + CacheSize: CacheSize, CacheTTL: CacheTTL, + Metrics: m, } c, err := o.newTokeninfoClient() require.NoError(t, err) - defer c.(*tokeninfoCache).client.(*authClient).Close() + defer c.Close() c.(*tokeninfoCache).now = clock.now ctx := &filtertest.Context{FRequest: &http.Request{}} @@ -111,7 +118,7 @@ func TestTokeninfoCache(t *testing.T) { assert.Equal(t, int32(1), authRequests, "expected no request to auth sever") assert.Equal(t, token, info["uid"]) - assert.Equal(t, float64(595), info["expires_in"], "expected TokenTTLSeconds - truncate(delay)") + assert.Equal(t, float64(594), info["expires_in"], "expected truncate(TokenTTLSeconds - delay)") // Third request after "sleeping" longer than cache TTL clock.add(CacheTTL) @@ -123,7 +130,7 @@ func TestTokeninfoCache(t *testing.T) { assert.Equal(t, token, info["uid"]) assert.Equal(t, float64(294), info["expires_in"], "expected truncate(TokenTTLSeconds - CacheTTL - delay)") - // Fourth request with a new token evicts cached value + // Fourth request with a new token token = newTestTokeninfoToken(clock.now()).String() info, err = c.getTokeninfo(token, ctx) @@ -132,6 +139,19 @@ func TestTokeninfoCache(t *testing.T) { assert.Equal(t, int32(3), authRequests, "expected new request to auth sever") assert.Equal(t, token, info["uid"]) assert.Equal(t, float64(600), info["expires_in"], "expected TokenTTLSeconds") + + // Force eviction and verify cache size is within limits + c.(*tokeninfoCache).evict() + m.WithGauges(func(g map[string]float64) { + assert.Equal(t, float64(CacheSize), g["tokeninfocache.count"]) + }) + + // Force eviction after all entries expired and verify cache is empty + clock.add(CacheTTL + time.Second) + c.(*tokeninfoCache).evict() + m.WithGauges(func(g map[string]float64) { + assert.Equal(t, float64(0), g["tokeninfocache.count"]) + }) } // Tests race between reading and writing cache for the same token @@ -152,7 +172,8 @@ func TestTokeninfoCacheUpdateRace(t *testing.T) { return map[string]any{"requestNumber": requestNumber, "uid": token, "expires_in": float64(600)}, nil }) - c := newTokeninfoCache(mc, 1, time.Hour) + c := newTokeninfoCache(mc, nil, 1, time.Hour) + defer c.Close() const token = "atoken" @@ -234,7 +255,8 @@ func BenchmarkTokeninfoCache(b *testing.B) { return tokenValues[token], nil }) - c := newTokeninfoCache(mc, bi.cacheSize, time.Hour) + c := newTokeninfoCache(mc, nil, bi.cacheSize, time.Hour) + defer c.Close() var tokens []string for i := 0; i < bi.tokens; i++ { diff --git a/skipper.go b/skipper.go index e0a3fc8f31..7044ac6d73 100644 --- a/skipper.go +++ b/skipper.go @@ -1615,6 +1615,7 @@ func run(o Options, sig chan os.Signal, idleConnsCH chan struct{}) error { Timeout: o.OAuthTokeninfoTimeout, MaxIdleConns: o.IdleConnectionsPerHost, Tracer: tracer, + Metrics: mtr, CacheSize: o.OAuthTokeninfoCacheSize, CacheTTL: o.OAuthTokeninfoCacheTTL, }