From 683c549a24d100e955f893f62054cab63e7e4822 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?D=C3=A1vid=20N=C3=A9meth=20Cs?=
Date: Thu, 14 Sep 2023 09:39:27 +0200
Subject: [PATCH] zmq4: fix race condition in Rep cancellation

* There is a race condition in repWriter.write(ctx context.Context, msg Msg):
  if the run() loop of repWriter has already exited because r.ctx was
  cancelled, the send on r.sendCh blocks forever, because nobody is
  reading the channel anymore.
  This commit adds a test that reproduces the issue, together with the fix.
* write() now also aborts when the context passed in by the caller is
  cancelled.
---
Review revision of the original patch:
  - document repWriter.write
  - TestCancellation: report a failed Listen as "could not listen", always
    release the requester so that a failed Listen cannot hang the test, and
    defer cancel() so that the context is released on early returns
  - the index lines below still carry the blob ids of the original patch

 rep.go      | 13 +++++++--
 rep_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/rep.go b/rep.go
index 8ca6f2a..b578bad 100644
--- a/rep.go
+++ b/rep.go
@@ -257,8 +257,17 @@ func (r *repWriter) rmConn(conn *Conn) {
 
+// write queues msg for the run loop, together with the connection and
+// preamble returned by r.state.Get(). It gives up and returns the context
+// error if ctx or r.ctx is cancelled before the run loop accepts the payload.
 func (r *repWriter) write(ctx context.Context, msg Msg) error {
 	conn, preamble := r.state.Get()
-	r.sendCh <- repSendPayload{conn, preamble, msg}
-	return nil
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-r.ctx.Done(): // run() exits on r.ctx, so nobody would receive from sendCh anymore
+		return r.ctx.Err()
+	case r.sendCh <- repSendPayload{conn, preamble, msg}:
+		return nil
+	}
 }
 
 func (r *repWriter) run() {
diff --git a/rep_test.go b/rep_test.go
index e4bcae4..9d5dba5 100644
--- a/rep_test.go
+++ b/rep_test.go
@@ -6,6 +6,7 @@ package zmq4_test
 
 import (
 	"context"
+	"errors"
 	"sync"
 	"testing"
 
@@ -92,3 +93,81 @@ func TestIssue99(t *testing.T) {
 		t.Fatalf("message length mismatch: got=%d, want=%d", got, want)
 	}
 }
+
+// TestCancellation checks that cancelling the Rep socket's context right
+// before rep.Send makes the responder fail with context.Canceled and the
+// requester's Recv return an error, instead of either side hanging.
+func TestCancellation(t *testing.T) {
+	var wg sync.WaitGroup
+
+	ep, err := EndPoint("tcp")
+	if err != nil {
+		t.Fatalf("could not find endpoint: %+v", err)
+	}
+
+	// Closed by the responder once Listen has returned.
+	responderStarted := make(chan struct{})
+
+	requester := func() {
+		defer wg.Done()
+		<-responderStarted
+
+		req := zmq4.NewReq(context.Background())
+		defer req.Close()
+
+		err := req.Dial(ep)
+		if err != nil {
+			t.Errorf("could not dial: %+v", err)
+			return
+		}
+
+		err = req.Send(zmq4.NewMsgString("ping"))
+		if err != nil {
+			t.Errorf("could not send: %+v", err)
+			return
+		}
+
+		msg, err := req.Recv()
+		if err == nil {
+			t.Errorf("requester should have gotten an error, but got: %+v", msg)
+		}
+	}
+
+	responder := func() {
+		defer wg.Done()
+
+		repCtx, cancel := context.WithCancel(context.Background())
+		defer cancel() // release the context on the early-return paths too
+		rep := zmq4.NewRep(repCtx)
+		defer rep.Close()
+
+		err := rep.Listen(ep)
+		// Release the requester even when Listen failed, so that the test
+		// reports the error instead of hanging in wg.Wait.
+		close(responderStarted)
+		if err != nil {
+			t.Errorf("could not listen: %+v", err)
+			return
+		}
+
+		_, err = rep.Recv()
+		if err != nil {
+			t.Errorf("could not recv: %+v", err)
+			return
+		}
+
+		// cancel the context right before sending the response
+		cancel()
+		err = rep.Send(zmq4.NewMsgString("pong"))
+		if !errors.Is(err, context.Canceled) {
+			t.Errorf("context should be cancelled: %+v", err)
+		}
+	}
+
+	wg.Add(2)
+
+	go requester()
+	go responder()
+
+	wg.Wait()
+}