diff --git a/loadbalancer/fadein_test.go b/loadbalancer/fadein_test.go index c05311fe23..a06532295c 100644 --- a/loadbalancer/fadein_test.go +++ b/loadbalancer/fadein_test.go @@ -62,17 +62,18 @@ func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Durat LBFadeInDuration: fadeInDuration, LBFadeInExponent: 1, }, - Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}), } + detected := map[string]time.Time{} for i := range eps { ctx.Route.LBEndpoints = append(ctx.Route.LBEndpoints, routing.LBEndpoint{ Host: eps[i], Detected: detectionTimes[i], }) - ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i]) + detected[eps[i]] = detectionTimes[i] } ctx.LBEndpoints = ctx.Route.LBEndpoints + ctx.Registry = routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected}) return ctx, eps } @@ -326,14 +327,15 @@ func benchmarkFadeIn( LBFadeInDuration: fadeInDuration, LBFadeInExponent: 1, } - registry := routing.NewEndpointRegistry(routing.RegistryOptions{}) + detected := map[string]time.Time{} for i := range eps { route.LBEndpoints = append(route.LBEndpoints, routing.LBEndpoint{ Host: eps[i], Detected: detectionTimes[i], }) - registry.SetDetectedTime(eps[i], detectionTimes[i]) + detected[eps[i]] = detectionTimes[i] } + registry := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected}) var wg sync.WaitGroup diff --git a/routing/endpointregistry.go b/routing/endpointregistry.go index 2d1679ffa9..c30253b005 100644 --- a/routing/endpointregistry.go +++ b/routing/endpointregistry.go @@ -2,6 +2,7 @@ package routing import ( "sync" + "sync/atomic" "time" "github.com/zalando/skipper/eskip" @@ -17,14 +18,22 @@ type Metrics interface { } type entry struct { - detected time.Time + detected atomic.Value // time.Time + lastSeen atomic.Value // time.Time inflightRequests int64 } var _ Metrics = &entry{} +func newEntry() (result *entry) { + result = &entry{} + result.detected.Store(time.Time{}) + result.lastSeen.Store(time.Time{}) + return +} + func (e *entry) DetectedTime() time.Time { - return e.detected + return e.detected.Load().(time.Time) } func (e *entry) InflightRequests() int64 { @@ -32,19 +41,18 @@ func (e *entry) InflightRequests() int64 { } type EndpointRegistry struct { - lastSeen map[string]time.Time lastSeenTimeout time.Duration now func() time.Time - mu sync.Mutex - - data map[string]*entry + // map[string]*entry + data sync.Map } var _ PostProcessor = &EndpointRegistry{} type RegistryOptions struct { LastSeenTimeout time.Duration + Detected map[string]time.Time } func (r *EndpointRegistry) Do(routes []*Route) []*Route { @@ -55,24 +63,22 @@ func (r *EndpointRegistry) Do(routes []*Route) []*Route { for _, epi := range route.LBEndpoints { metrics := r.GetMetrics(epi.Host) if metrics.DetectedTime().IsZero() { - r.SetDetectedTime(epi.Host, now) + r.setDetectedTime(epi.Host, now) } - r.lastSeen[epi.Host] = now + r.setLastSeenTime(epi.Host, now) } } } - for host, ts := range r.lastSeen { - if ts.Add(r.lastSeenTimeout).Before(now) { - r.mu.Lock() - if r.data[host].inflightRequests == 0 { - delete(r.lastSeen, host) - delete(r.data, host) - } - r.mu.Unlock() + r.data.Range(func(key, value any) bool { + e := value.(*entry) + if e.lastSeen.Load().(time.Time).Add(r.lastSeenTimeout).Before(now) && atomic.LoadInt64(&e.inflightRequests) == 0 { + r.data.Delete(key) } - } + + return true + }) return routes } @@ -82,58 +88,49 @@ func NewEndpointRegistry(o RegistryOptions) *EndpointRegistry { o.LastSeenTimeout = defaultLastSeenTimeout } - return &EndpointRegistry{ - data: map[string]*entry{}, - lastSeen: map[string]time.Time{}, + result := &EndpointRegistry{ + data: sync.Map{}, lastSeenTimeout: o.LastSeenTimeout, now: time.Now, } + + for host, detected := range o.Detected { + e := &entry{} + e.detected.Store(detected) + e.lastSeen.Store(result.now()) + result.data.Store(host, e) + } + + return result } func (r *EndpointRegistry) GetMetrics(key string) Metrics { - r.mu.Lock() - defer r.mu.Unlock() + e, _ := r.data.LoadOrStore(key, newEntry()) + old := e.(*entry) - e := r.getOrInitEntryLocked(key) - copy := &entry{} - *copy = *e + copy := newEntry() + copy.detected.Store(old.detected.Load()) + atomic.StoreInt64(©.inflightRequests, atomic.LoadInt64(&old.inflightRequests)) + // no need to copy lastSeen, because it is needed only for endpointregistry itself return copy } -func (r *EndpointRegistry) SetDetectedTime(key string, detected time.Time) { - r.mu.Lock() - defer r.mu.Unlock() +func (r *EndpointRegistry) setDetectedTime(key string, detected time.Time) { + e, _ := r.data.LoadOrStore(key, newEntry()) + e.(*entry).detected.Store(detected) +} - e := r.getOrInitEntryLocked(key) - e.detected = detected +func (r *EndpointRegistry) setLastSeenTime(key string, ts time.Time) { + e, _ := r.data.LoadOrStore(key, newEntry()) + e.(*entry).lastSeen.Store(ts) } func (r *EndpointRegistry) IncInflightRequest(key string) { - r.mu.Lock() - defer r.mu.Unlock() - - e := r.getOrInitEntryLocked(key) - e.inflightRequests++ + e, _ := r.data.LoadOrStore(key, newEntry()) + atomic.AddInt64(&e.(*entry).inflightRequests, 1) } func (r *EndpointRegistry) DecInflightRequest(key string) { - r.mu.Lock() - defer r.mu.Unlock() - - e := r.getOrInitEntryLocked(key) - e.inflightRequests-- -} - -// getOrInitEntryLocked returns pointer to endpoint registry entry -// which contains the information about endpoint representing the -// following key. r.mu must be held while calling this function and -// using of the entry returned. In general, key represents the "host:port" -// string -func (r *EndpointRegistry) getOrInitEntryLocked(key string) *entry { - e, ok := r.data[key] - if !ok { - e = &entry{} - r.data[key] = e - } - return e + e, _ := r.data.LoadOrStore(key, newEntry()) + atomic.AddInt64(&e.(*entry).inflightRequests, -1) } diff --git a/routing/endpointregistry_test.go b/routing/endpointregistry_test.go index 1256da4d17..d11f3f0cb4 100644 --- a/routing/endpointregistry_test.go +++ b/routing/endpointregistry_test.go @@ -20,6 +20,15 @@ func TestEmptyRegistry(t *testing.T) { assert.Equal(t, int64(0), m.InflightRequests()) } +func TestRegistryWithInitData(t *testing.T) { + now := time.Now() + r := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: map[string]time.Time{"some key": now}}) + m := r.GetMetrics("some key") + + assert.Equal(t, now, m.DetectedTime()) + assert.Equal(t, int64(0), m.InflightRequests()) +} + func TestSetAndGet(t *testing.T) { r := routing.NewEndpointRegistry(routing.RegistryOptions{}) @@ -29,14 +38,6 @@ func TestSetAndGet(t *testing.T) { assert.Equal(t, int64(0), mBefore.InflightRequests()) assert.Equal(t, int64(1), mAfter.InflightRequests()) - - ts, _ := time.Parse(time.DateOnly, "2023-08-29") - mBefore = r.GetMetrics("some key") - r.SetDetectedTime("some key", ts) - mAfter = r.GetMetrics("some key") - - assert.Equal(t, time.Time{}, mBefore.DetectedTime()) - assert.Equal(t, ts, mAfter.DetectedTime()) } func TestSetAndGetAnotherKey(t *testing.T) { @@ -101,6 +102,38 @@ func TestDoRemovesOldEntries(t *testing.T) { assert.Equal(t, int64(0), mRemoved.InflightRequests()) } +func TestRaceReadWrite(t *testing.T) { + r := routing.NewEndpointRegistry(routing.RegistryOptions{}) + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r.GetMetrics("some key") + }() + go func() { + defer wg.Done() + r.IncInflightRequest("some key") + }() + wg.Wait() +} + +func TestRaceTwoWriters(t *testing.T) { + r := routing.NewEndpointRegistry(routing.RegistryOptions{}) + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + r.DecInflightRequest("some key") + }() + go func() { + defer wg.Done() + r.IncInflightRequest("some key") + }() + wg.Wait() +} + func printTotalMutexWaitTime(b *testing.B) { // Name of the metric we want to read. const myMetric = "/sync/mutex/wait/total:seconds"