Skip to content

Commit

Permalink
zmq4: fix race condition in Rep cancellation
Browse files Browse the repository at this point in the history
* There is a race condition in repWriter write(ctx context.Context, msg Msg):

if the run() loop of repWriter has exited because r.ctx was cancelled then <- repSendPayload will block, because nobody is reading the channel anymore

I'm adding a test that reproduces the issue with a fix.

* We should also abort when the writer context is cancelled.
  • Loading branch information
encse committed Sep 14, 2023
1 parent bd7e871 commit 683c549
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
10 changes: 8 additions & 2 deletions rep.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,14 @@ func (r *repWriter) rmConn(conn *Conn) {

func (r *repWriter) write(ctx context.Context, msg Msg) error {
conn, preamble := r.state.Get()
r.sendCh <- repSendPayload{conn, preamble, msg}
return nil
select {
case <-ctx.Done():
return ctx.Err()
case <-r.ctx.Done(): // repWriter.run() terminates on this, sendCh <- will not complete
return r.ctx.Err()
case r.sendCh <- repSendPayload{conn, preamble, msg}:
return nil
}
}

func (r *repWriter) run() {
Expand Down
75 changes: 75 additions & 0 deletions rep_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zmq4_test

import (
"context"
"errors"
"sync"
"testing"

Expand Down Expand Up @@ -92,3 +93,77 @@ func TestIssue99(t *testing.T) {
t.Fatalf("message length mismatch: got=%d, want=%d", got, want)
}
}

func TestCancellation(t *testing.T) {
// if the context is cancelled during a rep.Send both the requester and the responder should get an error
var wg sync.WaitGroup

ep, err := EndPoint("tcp")
if err != nil {
t.Fatalf("could not find endpoint: %+v", err)
}

responderStarted := make(chan bool)

requester := func() {
defer wg.Done()
<-responderStarted

req := zmq4.NewReq(context.Background())
defer req.Close()

err := req.Dial(ep)
if err != nil {
t.Errorf("could not dial: %+v", err)
return
}

err = req.Send(zmq4.NewMsgString("ping"))
if err != nil {
t.Errorf("could not send: %+v", err)
return
}

msg, err := req.Recv()
if err == nil {
t.Errorf("requester should have gotten an error, but got: %+v", msg)
}
}

responder := func() {

defer wg.Done()
repCtx, cancel := context.WithCancel(context.Background())
rep := zmq4.NewRep(repCtx)
defer rep.Close()

err := rep.Listen(ep)
if err != nil {
t.Errorf("could not dial: %+v", err)
return
}

responderStarted <- true

_, err = rep.Recv()
if err != nil {
t.Errorf("could not recv: %+v", err)
return
}

// cancel the context right before sending the response
cancel()
err = rep.Send(zmq4.NewMsgString("pong"))

if !errors.Is(err, context.Canceled) {
t.Errorf("context should be cancelled: %+v", err)
}
}

wg.Add(2)

go requester()
go responder()

wg.Wait()
}

0 comments on commit 683c549

Please sign in to comment.