-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Track data race issue by using channels
- Loading branch information
Showing
3 changed files
with
186 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,106 @@ | ||
package track | ||
|
||
import ( | ||
"sync" | ||
"context" | ||
"fmt" | ||
"time" | ||
) | ||
|
||
type TickFunc func(key string, fn func()) | ||
|
||
type trackCmd int | ||
|
||
const ( | ||
call trackCmd = iota | ||
remove | ||
execute | ||
) | ||
|
||
func (c trackCmd) String() string { | ||
switch c { | ||
case call: | ||
return "call" | ||
case remove: | ||
return "remove" | ||
case execute: | ||
return "execute" | ||
default: | ||
return "unknown" | ||
} | ||
} | ||
|
||
type track struct { | ||
key string | ||
cmd trackCmd | ||
timer *time.Timer | ||
count int | ||
fn func() | ||
} | ||
|
||
func Create(d time.Duration, n int) TickFunc { | ||
var mu sync.Mutex | ||
var tracks = make(map[string]*track) | ||
func (t *track) String() string { | ||
return fmt.Sprintf("key: %s, cmd: %s, count: %d", t.key, t.cmd, t.count) | ||
} | ||
|
||
return func(key string, fn func()) { | ||
mu.Lock() | ||
t, ok := tracks[key] | ||
if !ok { | ||
t = &track{ | ||
count: n, | ||
func Create(ctx context.Context, d time.Duration, n int) TickFunc { | ||
ch := make(chan *track, 10) | ||
tracks := make(map[string]*track) | ||
|
||
go func() { | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
case t := <-ch: | ||
switch t.cmd { | ||
case call: | ||
curr, ok := tracks[t.key] | ||
if ok { | ||
curr.timer.Stop() | ||
curr.fn = t.fn | ||
} else { | ||
tracks[t.key] = t | ||
curr = t | ||
} | ||
|
||
curr.count-- | ||
if curr.count <= 0 { | ||
curr.timer.Stop() | ||
ch <- &track{ | ||
key: t.key, | ||
cmd: execute, | ||
} | ||
} else { | ||
curr.timer.Reset(d) | ||
} | ||
case remove: | ||
curr, ok := tracks[t.key] | ||
if ok { | ||
curr.timer.Stop() | ||
delete(tracks, t.key) | ||
} | ||
case execute: | ||
curr, ok := tracks[t.key] | ||
if ok { | ||
delete(tracks, t.key) | ||
curr.fn() | ||
} | ||
} | ||
} | ||
t.timer = time.AfterFunc(d, func() { | ||
mu.Lock() | ||
delete(tracks, key) | ||
mu.Unlock() | ||
t.fn() | ||
}) | ||
tracks[key] = t // No need to lock/unlock here | ||
} | ||
mu.Unlock() | ||
t.fn = fn | ||
t.count-- | ||
|
||
if t.count == 0 { | ||
t.timer.Stop() | ||
mu.Lock() | ||
delete(tracks, key) | ||
mu.Unlock() | ||
go fn() | ||
}() | ||
|
||
return func(key string, fn func()) { | ||
ch <- &track{ | ||
key: key, | ||
cmd: call, | ||
fn: fn, | ||
count: n, | ||
timer: time.AfterFunc(d, func() { | ||
ch <- &track{ | ||
key: key, | ||
cmd: execute, | ||
} | ||
}), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package track_test | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
|
||
"ella.to/bus/internal/track" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestTrack(t *testing.T) { | ||
const timeout = 500 * time.Millisecond | ||
const n = 100 | ||
const key = "key" | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
fn := track.Create(ctx, timeout, n) | ||
|
||
for i := 0; i < n; i++ { | ||
fn(key, func() { | ||
fmt.Println("key Called on", i) | ||
}) | ||
} | ||
|
||
wg := sync.WaitGroup{} | ||
wg.Add(1) | ||
|
||
start := time.Now() | ||
fn(key, func() { | ||
defer wg.Done() | ||
fmt.Println("key Called on", time.Now().Sub(start).Seconds()) | ||
}) | ||
|
||
wg.Wait() | ||
} | ||
|
||
func TestTrackShouldBeCalledOnce(t *testing.T) { | ||
const timeout = 500 * time.Millisecond | ||
const n = 100 | ||
const key = "key" | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
fn := track.Create(ctx, timeout, n) | ||
|
||
for i := 0; i < n; i++ { | ||
fn(key, func() { | ||
fmt.Println("key Called on", i) | ||
}) | ||
} | ||
|
||
var counter int64 | ||
|
||
start := time.Now() | ||
fn(key, func() { | ||
atomic.AddInt64(&counter, 1) | ||
fmt.Println("key Called on", time.Since(start).Seconds()) | ||
}) | ||
|
||
time.Sleep(2 * time.Second) | ||
assert.Equal(t, int64(1), atomic.LoadInt64(&counter)) | ||
} | ||
|
||
func TestTrackUnderLoad(t *testing.T) { | ||
const timeout = 500 * time.Millisecond | ||
const n = 10 | ||
const workerCount = 3 | ||
const workerCalls = 4 | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
fn := track.Create(ctx, timeout, n) | ||
|
||
var wg sync.WaitGroup | ||
|
||
wg.Add(workerCount) | ||
|
||
for i := 0; i < workerCount; i++ { | ||
go func() { | ||
defer wg.Done() | ||
for j := 0; j < workerCalls; j++ { | ||
fn("key", func() { | ||
fmt.Printf("@@@@@key from %d Called on %d\n", i, j) | ||
}) | ||
} | ||
}() | ||
} | ||
|
||
wg.Wait() | ||
|
||
time.Sleep(timeout * 2) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters