Skip to content

Commit

Permalink
fix channel closer
Browse files Browse the repository at this point in the history
  • Loading branch information
hailin0 committed Jul 3, 2023
1 parent aa1c414 commit a1ab7f0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 63 deletions.
37 changes: 15 additions & 22 deletions pkg/run/channel_closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@ var dummyChannelCloserChan <-chan struct{}

// ChannelCloser can close a goroutine then wait for it to stop.
type ChannelCloser struct {
ctx context.Context
cancel context.CancelFunc
sender sync.WaitGroup
receiver sync.WaitGroup
senderLock sync.RWMutex
receiverLock sync.RWMutex
senderClosed bool
receiverClosed bool
ctx context.Context
cancel context.CancelFunc
sender sync.WaitGroup
receiver sync.WaitGroup
lock sync.RWMutex
closed bool
}

// NewChannelCloser instances a new ChannelCloser.
Expand All @@ -50,9 +48,9 @@ func (c *ChannelCloser) AddRunning() bool {
if c == nil {
return false
}
c.senderLock.RLock()
defer c.senderLock.RUnlock()
if c.senderClosed {
c.lock.RLock()
defer c.lock.RUnlock()
if c.closed {
return false
}
c.sender.Add(1)
Expand Down Expand Up @@ -89,19 +87,14 @@ func (c *ChannelCloser) CloseThenWait() {
return
}

c.senderLock.Lock()
c.senderClosed = true
c.senderLock.Unlock()
c.lock.Lock()
c.closed = true
c.lock.Unlock()

c.sender.Done()
c.sender.Wait()

c.cancel()

c.receiverLock.Lock()
c.receiverClosed = true
c.receiverLock.Unlock()

c.receiver.Wait()
}

Expand All @@ -110,7 +103,7 @@ func (c *ChannelCloser) Closed() bool {
if c == nil {
return true
}
c.receiverLock.RLock()
defer c.receiverLock.RUnlock()
return c.receiverClosed
c.lock.RLock()
defer c.lock.RUnlock()
return c.closed
}
51 changes: 20 additions & 31 deletions pkg/run/channel_closer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ var _ = ginkgo.Describe("ChannelCloser", func() {
var wg sync.WaitGroup
wg.Add(groupAWorkerNum + groupBWorkerNum + 2)

chanL1 := make(chan struct{})
chanL2 := make(chan struct{})
chanA := make(chan struct{})
chanB := make(chan struct{})
chanCloser := NewChannelCloser(3)

for i := 0; i < groupAWorkerNum; i++ {
Expand All @@ -110,7 +110,7 @@ var _ = ginkgo.Describe("ChannelCloser", func() {
for {
if chanCloser.AddRunning() {
time.Sleep(5 * time.Millisecond)
chanL1 <- struct{}{}
chanA <- struct{}{}
chanCloser.RunningDone()
} else {
fmt.Printf("Stop worker - %d\n", index)
Expand All @@ -128,7 +128,7 @@ var _ = ginkgo.Describe("ChannelCloser", func() {
for {
if chanCloser.AddRunning() {
time.Sleep(5 * time.Millisecond)
chanL2 <- struct{}{}
chanB <- struct{}{}
chanCloser.RunningDone()
} else {
fmt.Printf("Stop worker - %d\n", index)
Expand All @@ -150,7 +150,7 @@ var _ = ginkgo.Describe("ChannelCloser", func() {

for {
select {
case <-chanL1:
case <-chanA:
time.Sleep(10 * time.Millisecond)
case <-chanCloser.CloseNotify():
return
Expand All @@ -170,7 +170,7 @@ var _ = ginkgo.Describe("ChannelCloser", func() {

for {
select {
case <-chanL2:
case <-chanB:
time.Sleep(10 * time.Millisecond)
case <-chanCloser.CloseNotify():
return
Expand All @@ -193,18 +193,19 @@ var _ = ginkgo.Describe("ChannelCloser", func() {

chanL1 := make(chan struct{})
chanL2 := make(chan struct{})
chanCloser := NewChannelCloser(3)
chanL1Closer := NewChannelCloser(2)
chanL2Closer := NewChannelCloser(2)

for i := 0; i < workerNum; i++ {
go func(index int) {
wg.Done()

fmt.Printf("Start worker - %d\n", index)
for {
if chanCloser.AddRunning() {
if chanL1Closer.AddRunning() {
time.Sleep(5 * time.Millisecond)
chanL1 <- struct{}{}
chanCloser.RunningDone()
chanL1Closer.RunningDone()
} else {
fmt.Printf("Stop worker - %d\n", index)
return
Expand All @@ -220,29 +221,14 @@ var _ = ginkgo.Describe("ChannelCloser", func() {

defer func() {
fmt.Printf("Stop consumer: chanL1\n")
chanCloser.Done()
chanL1Closer.Done()
}()

for {
select {
case req := <-chanL1:

ExitSendChan:
for {
select {
case chanL2 <- req:
// logical code
break ExitSendChan
default:
}
if chanCloser.Closed() {
fmt.Printf("Discard unprocessed record: %v, consumer: chanL1\n", req)
return
}
time.Sleep(10 * time.Millisecond)
}

case <-chanCloser.CloseNotify():
chanL2 <- req
case <-chanL1Closer.CloseNotify():
return
}
}
Expand All @@ -255,14 +241,14 @@ var _ = ginkgo.Describe("ChannelCloser", func() {

defer func() {
fmt.Printf("Stop consumer: chanL2\n")
chanCloser.Done()
chanL2Closer.Done()
}()

for {
select {
case <-chanL2:
time.Sleep(10 * time.Millisecond)
case <-chanCloser.CloseNotify():
case <-chanL2Closer.CloseNotify():
return
}
}
Expand All @@ -271,8 +257,11 @@ var _ = ginkgo.Describe("ChannelCloser", func() {
wg.Wait()

fmt.Printf("Start close...\n")
chanCloser.Done()
chanCloser.CloseThenWait()
chanL1Closer.Done()
chanL1Closer.CloseThenWait()

chanL2Closer.Done()
chanL2Closer.CloseThenWait()
fmt.Printf("Stop close\n")
})
})
Expand Down
25 changes: 15 additions & 10 deletions pkg/wal/wal.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ type LogEntry interface {

// log implements the WAL interface.
type log struct {
entryCloser *run.ChannelCloser
writeCloser *run.ChannelCloser
flushCloser *run.ChannelCloser
buffer buffer
logger *logger.Logger
bytesBuffer *bytes.Buffer
Expand Down Expand Up @@ -196,7 +197,8 @@ func New(path string, options *Options) (WAL, error) {
flushChannel: make(chan buffer, walOptions.FlushQueueSize),
bytesBuffer: bytes.NewBuffer([]byte{}),
timestampsBuffer: bytes.NewBuffer([]byte{}),
entryCloser: run.NewChannelCloser(3),
writeCloser: run.NewChannelCloser(2),
flushCloser: run.NewChannelCloser(2),
buffer: buffer{
timestampMap: make(map[common.SeriesIDV2][]time.Time),
valueMap: make(map[common.SeriesIDV2][]byte),
Expand All @@ -217,10 +219,10 @@ func New(path string, options *Options) (WAL, error) {
// It will return immediately when the data is written in the buffer,
// The callback function will be called when the entity is flushed on the persistent storage.
func (log *log) Write(seriesID common.SeriesIDV2, timestamp time.Time, data []byte, callback func(common.SeriesIDV2, time.Time, []byte, error)) {
if !log.entryCloser.AddRunning() {
if !log.writeCloser.AddRunning() {
return
}
defer log.entryCloser.RunningDone()
defer log.writeCloser.RunningDone()

log.writeChannel <- logRequest{
seriesID: seriesID,
Expand Down Expand Up @@ -305,8 +307,11 @@ func (log *log) Close() error {
log.closerOnce.Do(func() {
log.logger.Info().Msg("Closing WAL...")

log.entryCloser.Done()
log.entryCloser.CloseThenWait()
log.writeCloser.Done()
log.writeCloser.CloseThenWait()

log.flushCloser.Done()
log.flushCloser.CloseThenWait()

if err := log.flushBuffer(log.buffer); err != nil {
globalErr = multierr.Append(globalErr, err)
Expand All @@ -323,7 +328,7 @@ func (log *log) start() {
go func() {
log.logger.Info().Msg("Start batch task...")

defer log.entryCloser.Done()
defer log.writeCloser.Done()

bufferVolume := 0
for {
Expand Down Expand Up @@ -352,7 +357,7 @@ func (log *log) start() {
}
log.triggerFlushing()
bufferVolume = 0
case <-log.entryCloser.CloseNotify():
case <-log.writeCloser.CloseNotify():
timer.Stop()
log.logger.Info().Msg("Stop batch task when close notify")
return
Expand All @@ -363,7 +368,7 @@ func (log *log) start() {
go func() {
log.logger.Info().Msg("Start flush task...")

defer log.entryCloser.Done()
defer log.flushCloser.Done()

for {
select {
Expand All @@ -389,7 +394,7 @@ func (log *log) start() {
}

batch.notifyRequests(err)
case <-log.entryCloser.CloseNotify():
case <-log.flushCloser.CloseNotify():
log.logger.Info().Msg("Stop flush task when close notify")
return
}
Expand Down

0 comments on commit a1ab7f0

Please sign in to comment.