diff --git a/rep.go b/rep.go index 8ca6f2a..b578bad 100644 --- a/rep.go +++ b/rep.go @@ -257,8 +257,14 @@ func (r *repWriter) rmConn(conn *Conn) { func (r *repWriter) write(ctx context.Context, msg Msg) error { conn, preamble := r.state.Get() - r.sendCh <- repSendPayload{conn, preamble, msg} - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-r.ctx.Done(): // repWriter.run() terminates on this, sendCh <- will not complete + return r.ctx.Err() + case r.sendCh <- repSendPayload{conn, preamble, msg}: + return nil + } } func (r *repWriter) run() { diff --git a/rep_test.go b/rep_test.go index e4bcae4..9d5dba5 100644 --- a/rep_test.go +++ b/rep_test.go @@ -6,6 +6,7 @@ package zmq4_test import ( "context" + "errors" "sync" "testing" @@ -92,3 +93,77 @@ func TestIssue99(t *testing.T) { t.Fatalf("message length mismatch: got=%d, want=%d", got, want) } } + +func TestCancellation(t *testing.T) { + // if the context is cancelled during a rep.Send both the requester and the responder should get an error + var wg sync.WaitGroup + + ep, err := EndPoint("tcp") + if err != nil { + t.Fatalf("could not find endpoint: %+v", err) + } + + responderStarted := make(chan bool) + + requester := func() { + defer wg.Done() + <-responderStarted + + req := zmq4.NewReq(context.Background()) + defer req.Close() + + err := req.Dial(ep) + if err != nil { + t.Errorf("could not dial: %+v", err) + return + } + + err = req.Send(zmq4.NewMsgString("ping")) + if err != nil { + t.Errorf("could not send: %+v", err) + return + } + + msg, err := req.Recv() + if err == nil { + t.Errorf("requester should have gotten an error, but got: %+v", msg) + } + } + + responder := func() { + + defer wg.Done() + repCtx, cancel := context.WithCancel(context.Background()) + rep := zmq4.NewRep(repCtx) + defer rep.Close() + + err := rep.Listen(ep) + if err != nil { + t.Errorf("could not dial: %+v", err) + return + } + + responderStarted <- true + + _, err = rep.Recv() + if err != nil { + t.Errorf("could not recv: %+v", err) + return + } + + // cancel the context right before sending the response + cancel() + err = rep.Send(zmq4.NewMsgString("pong")) + + if !errors.Is(err, context.Canceled) { + t.Errorf("context should be cancelled: %+v", err) + } + } + + wg.Add(2) + + go requester() + go responder() + + wg.Wait() +}