Skip to content

Commit

Permalink
Use sync.Map in the endpointregistry
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Dec 6, 2023
1 parent 2fd26a2 commit 3ac5d78
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 66 deletions.
10 changes: 6 additions & 4 deletions loadbalancer/fadein_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,18 @@ func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Durat
LBFadeInDuration: fadeInDuration,
LBFadeInExponent: 1,
},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
}

detected := map[string]time.Time{}
for i := range eps {
ctx.Route.LBEndpoints = append(ctx.Route.LBEndpoints, routing.LBEndpoint{
Host: eps[i],
Detected: detectionTimes[i],
})
ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i])
detected[eps[i]] = detectionTimes[i]
}
ctx.LBEndpoints = ctx.Route.LBEndpoints
ctx.Registry = routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected})

return ctx, eps
}
Expand Down Expand Up @@ -326,14 +327,15 @@ func benchmarkFadeIn(
LBFadeInDuration: fadeInDuration,
LBFadeInExponent: 1,
}
registry := routing.NewEndpointRegistry(routing.RegistryOptions{})
detected := map[string]time.Time{}
for i := range eps {
route.LBEndpoints = append(route.LBEndpoints, routing.LBEndpoint{
Host: eps[i],
Detected: detectionTimes[i],
})
registry.SetDetectedTime(eps[i], detectionTimes[i])
detected[eps[i]] = detectionTimes[i]
}
registry := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: detected})

var wg sync.WaitGroup

Expand Down
105 changes: 51 additions & 54 deletions routing/endpointregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package routing

import (
"sync"
"sync/atomic"
"time"

"github.com/zalando/skipper/eskip"
Expand All @@ -17,34 +18,41 @@ type Metrics interface {
}

type entry struct {
detected time.Time
detected atomic.Value // time.Time
lastSeen atomic.Value // time.Time
inflightRequests int64
}

var _ Metrics = &entry{}

func newEntry() (result *entry) {
result = &entry{}
result.detected.Store(time.Time{})
result.lastSeen.Store(time.Time{})
return
}

func (e *entry) DetectedTime() time.Time {
return e.detected
return e.detected.Load().(time.Time)
}

func (e *entry) InflightRequests() int64 {
return e.inflightRequests
}

type EndpointRegistry struct {
lastSeen map[string]time.Time
lastSeenTimeout time.Duration
now func() time.Time

mu sync.Mutex

data map[string]*entry
// map[string]*entry
data sync.Map
}

var _ PostProcessor = &EndpointRegistry{}

type RegistryOptions struct {
LastSeenTimeout time.Duration
Detected map[string]time.Time
}

func (r *EndpointRegistry) Do(routes []*Route) []*Route {
Expand All @@ -55,24 +63,22 @@ func (r *EndpointRegistry) Do(routes []*Route) []*Route {
for _, epi := range route.LBEndpoints {
metrics := r.GetMetrics(epi.Host)
if metrics.DetectedTime().IsZero() {
r.SetDetectedTime(epi.Host, now)
r.setDetectedTime(epi.Host, now)
}

r.lastSeen[epi.Host] = now
r.setLastSeenTime(epi.Host, now)
}
}
}

for host, ts := range r.lastSeen {
if ts.Add(r.lastSeenTimeout).Before(now) {
r.mu.Lock()
if r.data[host].inflightRequests == 0 {
delete(r.lastSeen, host)
delete(r.data, host)
}
r.mu.Unlock()
r.data.Range(func(key, value any) bool {
e := value.(*entry)
if e.lastSeen.Load().(time.Time).Add(r.lastSeenTimeout).Before(now) && atomic.LoadInt64(&e.inflightRequests) == 0 {
r.data.Delete(key)
}
}

return true
})

return routes
}
Expand All @@ -82,58 +88,49 @@ func NewEndpointRegistry(o RegistryOptions) *EndpointRegistry {
o.LastSeenTimeout = defaultLastSeenTimeout
}

return &EndpointRegistry{
data: map[string]*entry{},
lastSeen: map[string]time.Time{},
result := &EndpointRegistry{
data: sync.Map{},
lastSeenTimeout: o.LastSeenTimeout,
now: time.Now,
}

for host, detected := range o.Detected {
e := &entry{}
e.detected.Store(detected)
e.lastSeen.Store(result.now())
result.data.Store(host, e)
}

return result
}

func (r *EndpointRegistry) GetMetrics(key string) Metrics {
r.mu.Lock()
defer r.mu.Unlock()
e, _ := r.data.LoadOrStore(key, newEntry())
old := e.(*entry)

e := r.getOrInitEntryLocked(key)
copy := &entry{}
*copy = *e
copy := newEntry()
copy.detected.Store(old.detected.Load())
atomic.StoreInt64(&copy.inflightRequests, atomic.LoadInt64(&old.inflightRequests))
// no need to copy lastSeen, because it is needed only for endpointregistry itself
return copy
}

func (r *EndpointRegistry) SetDetectedTime(key string, detected time.Time) {
r.mu.Lock()
defer r.mu.Unlock()
func (r *EndpointRegistry) setDetectedTime(key string, detected time.Time) {
e, _ := r.data.LoadOrStore(key, newEntry())
e.(*entry).detected.Store(detected)
}

e := r.getOrInitEntryLocked(key)
e.detected = detected
func (r *EndpointRegistry) setLastSeenTime(key string, ts time.Time) {
e, _ := r.data.LoadOrStore(key, newEntry())
e.(*entry).lastSeen.Store(ts)
}

func (r *EndpointRegistry) IncInflightRequest(key string) {
r.mu.Lock()
defer r.mu.Unlock()

e := r.getOrInitEntryLocked(key)
e.inflightRequests++
e, _ := r.data.LoadOrStore(key, newEntry())
atomic.AddInt64(&e.(*entry).inflightRequests, 1)
}

func (r *EndpointRegistry) DecInflightRequest(key string) {
r.mu.Lock()
defer r.mu.Unlock()

e := r.getOrInitEntryLocked(key)
e.inflightRequests--
}

// getOrInitEntryLocked returns pointer to endpoint registry entry
// which contains the information about endpoint representing the
// following key. r.mu must be held while calling this function and
// using of the entry returned. In general, key represents the "host:port"
// string
func (r *EndpointRegistry) getOrInitEntryLocked(key string) *entry {
e, ok := r.data[key]
if !ok {
e = &entry{}
r.data[key] = e
}
return e
e, _ := r.data.LoadOrStore(key, newEntry())
atomic.AddInt64(&e.(*entry).inflightRequests, -1)
}
33 changes: 25 additions & 8 deletions routing/endpointregistry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ func TestEmptyRegistry(t *testing.T) {
assert.Equal(t, int64(0), m.InflightRequests())
}

func TestRegistryWithInitData(t *testing.T) {
now := time.Now()
r := routing.NewEndpointRegistry(routing.RegistryOptions{Detected: map[string]time.Time{"some key": now}})
m := r.GetMetrics("some key")

assert.Equal(t, now, m.DetectedTime())
assert.Equal(t, int64(0), m.InflightRequests())
}

func TestSetAndGet(t *testing.T) {
r := routing.NewEndpointRegistry(routing.RegistryOptions{})

Expand All @@ -29,14 +38,6 @@ func TestSetAndGet(t *testing.T) {

assert.Equal(t, int64(0), mBefore.InflightRequests())
assert.Equal(t, int64(1), mAfter.InflightRequests())

ts, _ := time.Parse(time.DateOnly, "2023-08-29")
mBefore = r.GetMetrics("some key")
r.SetDetectedTime("some key", ts)
mAfter = r.GetMetrics("some key")

assert.Equal(t, time.Time{}, mBefore.DetectedTime())
assert.Equal(t, ts, mAfter.DetectedTime())
}

func TestSetAndGetAnotherKey(t *testing.T) {
Expand Down Expand Up @@ -101,6 +102,22 @@ func TestDoRemovesOldEntries(t *testing.T) {
assert.Equal(t, int64(0), mRemoved.InflightRequests())
}

func TestRace(t *testing.T) {
r := routing.NewEndpointRegistry(routing.RegistryOptions{})

wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
r.GetMetrics("some key")
}()
go func() {
defer wg.Done()
r.IncInflightRequest("some key")
}()
wg.Wait()
}

func printTotalMutexWaitTime(b *testing.B) {
// Name of the metric we want to read.
const myMetric = "/sync/mutex/wait/total:seconds"
Expand Down

0 comments on commit 3ac5d78

Please sign in to comment.