diff --git a/internal/track/track.go b/internal/track/track.go index 2e971ce..6ad0e8d 100644 --- a/internal/track/track.go +++ b/internal/track/track.go @@ -1,47 +1,106 @@ package track import ( - "sync" + "context" + "fmt" "time" ) type TickFunc func(key string, fn func()) +type trackCmd int + +const ( + call trackCmd = iota + remove + execute +) + +func (c trackCmd) String() string { + switch c { + case call: + return "call" + case remove: + return "remove" + case execute: + return "execute" + default: + return "unknown" + } +} + type track struct { + key string + cmd trackCmd timer *time.Timer count int fn func() } -func Create(d time.Duration, n int) TickFunc { - var mu sync.Mutex - var tracks = make(map[string]*track) +func (t *track) String() string { + return fmt.Sprintf("key: %s, cmd: %s, count: %d", t.key, t.cmd, t.count) +} - return func(key string, fn func()) { - mu.Lock() - t, ok := tracks[key] - if !ok { - t = &track{ - count: n, +func Create(ctx context.Context, d time.Duration, n int) TickFunc { + ch := make(chan *track, 10) + tracks := make(map[string]*track) + + go func() { + for { + select { + case <-ctx.Done(): + return + case t := <-ch: + switch t.cmd { + case call: + curr, ok := tracks[t.key] + if ok { + curr.timer.Stop() + curr.fn = t.fn + } else { + tracks[t.key] = t + curr = t + } + + curr.count-- + if curr.count <= 0 { + curr.timer.Stop() + ch <- &track{ + key: t.key, + cmd: execute, + } + } else { + curr.timer.Reset(d) + } + case remove: + curr, ok := tracks[t.key] + if ok { + curr.timer.Stop() + delete(tracks, t.key) + } + case execute: + curr, ok := tracks[t.key] + if ok { + delete(tracks, t.key) + curr.fn() + } + } } - t.timer = time.AfterFunc(d, func() { - mu.Lock() - delete(tracks, key) - mu.Unlock() - t.fn() - }) - tracks[key] = t // No need to lock/unlock here } - mu.Unlock() - t.fn = fn - t.count-- - - if t.count == 0 { - t.timer.Stop() - mu.Lock() - delete(tracks, key) - mu.Unlock() - go fn() + }() + + return func(key string, fn func()) { + ch <- &track{ + key: key, + cmd: call, + fn: fn, + count: n, + timer: time.AfterFunc(d, func() { + ch <- &track{ + key: key, + cmd: execute, + } + }), } } } diff --git a/internal/track/track_test.go b/internal/track/track_test.go new file mode 100644 index 0000000..dc0e63b --- /dev/null +++ b/internal/track/track_test.go @@ -0,0 +1,99 @@ +package track_test + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "ella.to/bus/internal/track" + "github.com/stretchr/testify/assert" +) + +func TestTrack(t *testing.T) { + const timeout = 500 * time.Millisecond + const n = 100 + const key = "key" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fn := track.Create(ctx, timeout, n) + + for i := 0; i < n; i++ { + fn(key, func() { + fmt.Println("key Called on", i) + }) + } + + wg := sync.WaitGroup{} + wg.Add(1) + + start := time.Now() + fn(key, func() { + defer wg.Done() + fmt.Println("key Called on", time.Now().Sub(start).Seconds()) + }) + + wg.Wait() +} + +func TestTrackShouldBeCalledOnce(t *testing.T) { + const timeout = 500 * time.Millisecond + const n = 100 + const key = "key" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fn := track.Create(ctx, timeout, n) + + for i := 0; i < n; i++ { + fn(key, func() { + fmt.Println("key Called on", i) + }) + } + + var counter int64 + + start := time.Now() + fn(key, func() { + atomic.AddInt64(&counter, 1) + fmt.Println("key Called on", time.Since(start).Seconds()) + }) + + time.Sleep(2 * time.Second) + assert.Equal(t, int64(1), atomic.LoadInt64(&counter)) +} + +func TestTrackUnderLoad(t *testing.T) { + const timeout = 500 * time.Millisecond + const n = 10 + const workerCount = 3 + const workerCalls = 4 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + fn := track.Create(ctx, timeout, n) + + var wg sync.WaitGroup + + wg.Add(workerCount) + + for i := 0; i < workerCount; i++ { + go func() { + defer wg.Done() + for j := 0; j < workerCalls; j++ { + fn("key", func() { + fmt.Printf("@@@@@key from %d Called on %d\n", i, j) + }) + } + }() + } + + wg.Wait() + + time.Sleep(timeout * 2) +} diff --git a/server/server.go b/server/server.go index de09999..13589c5 100644 --- a/server/server.go +++ b/server/server.go @@ -374,7 +374,7 @@ func New(ctx context.Context, opts ...Opt) (*Server, error) { s := &Server{ consumersMap: bus.NewConsumersEventMap(conf.consumerQueueSize), incomingEvents: make(chan *incomingEvent, conf.incomingEventsBufferSize), - tick: track.Create(conf.tickTimeout, conf.tickSize), + tick: track.Create(ctx, conf.tickTimeout, conf.tickSize), closeSignal: make(chan struct{}), }