diff --git a/internal/errgroup/errgroup.go b/internal/errgroup/errgroup.go new file mode 100644 index 0000000..ac4acfa --- /dev/null +++ b/internal/errgroup/errgroup.go @@ -0,0 +1,107 @@ +// Copyright 2023 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package errgroup is bit more advanced than golang.org/x/sync/errgroup. +// Major difference is that when error group is created with WithContext +// the parent context would implicitly cancel all functions called by Go method. +package errgroup + +import ( + "context" + + "golang.org/x/sync/errgroup" +) + +// The Group is superior errgroup.Group which aborts whole group +// execution when parent context is cancelled +type Group struct { + grp *errgroup.Group + ctx context.Context +} + +// WithContext creates Group and store inside parent context +// so the Go method would respect parent context cancellation +func WithContext(ctx context.Context) (*Group, context.Context) { + grp, child_ctx := errgroup.WithContext(ctx) + return &Group{grp: grp, ctx: ctx}, child_ctx +} + +// Go runs the provided f function in a dedicated goroutine and waits for its +// completion or for the parent context cancellation. +func (g *Group) Go(f func() error) { + g.getErrGroup().Go(g.wrap(f)) +} + +// Wait blocks until all function calls from the Go method have returned, then +// returns the first non-nil error (if any) from them. +// If the error group was created via WithContext then the Wait returns error +// of cancelled parent context prior any functions calls complete. +func (g *Group) Wait() error { + return g.getErrGroup().Wait() +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *Group) SetLimit(n int) { + g.getErrGroup().SetLimit(n) +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *Group) TryGo(f func() error) bool { + return g.getErrGroup().TryGo(g.wrap(f)) +} + +func (g *Group) wrap(f func() error) func() error { + if g.ctx == nil { + return f + } + + return func() error { + // If parent context is canceled, + // just return its error and do not call func f + select { + case <-g.ctx.Done(): + return g.ctx.Err() + default: + } + + // Create return channel and call func f + // Buffered channel is used as the following select + // may be exiting by context cancellation + // and in such case the write to channel can be block + // and cause the go routine leak + ch := make(chan error, 1) + go func() { + ch <- f() + }() + + // Wait func f complete or + // parent context to be cancelled, + select { + case err := <-ch: + return err + case <-g.ctx.Done(): + return g.ctx.Err() + } + } +} + +// The getErrGroup returns actual x/sync/errgroup.Group. +// If the group is not allocated it would implicitly allocate it. +// Thats allows the internal/errgroup.Group be fully +// compatible to x/sync/errgroup.Group +func (g *Group) getErrGroup() *errgroup.Group { + if g.grp == nil { + g.grp = &errgroup.Group{} + } + return g.grp +} diff --git a/internal/errgroup/errgroup_test.go b/internal/errgroup/errgroup_test.go new file mode 100644 index 0000000..7f6b810 --- /dev/null +++ b/internal/errgroup/errgroup_test.go @@ -0,0 +1,124 @@ +// Copyright 2023 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package errgroup + +import ( + "context" + "fmt" + "testing" + + "golang.org/x/sync/errgroup" +) + +// TestRegularErrGroupDoesNotRespectParentContext checks regular errgroup behavior +// where errgroup.WithContext does not respect the parent context +func TestRegularErrGroupDoesNotRespectParentContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + eg, _ := errgroup.WithContext(ctx) + + what := fmt.Errorf("func generated error") + ch := make(chan error) + eg.Go(func() error { return <-ch }) + + cancel() // abort parent context + ch <- what // signal the func in regular errgroup to fail + err := eg.Wait() + + // The error shall be one returned by the function + // as regular errgroup.WithContext does not respect parent context + if err != what { + t.Errorf("invalid error. got=%+v, want=%+v", err, what) + } +} + +// TestErrGroupWithContextCanCallFunctions checks the errgroup operations +// are fine working and errgroup called function can return error +func TestErrGroupWithContextCanCallFunctions(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + eg, _ := WithContext(ctx) + + what := fmt.Errorf("func generated error") + ch := make(chan error) + eg.Go(func() error { return <-ch }) + + ch <- what // signal the func in errgroup to fail + err := eg.Wait() // wait errgroup complete and read error + + // The error shall be one returned by the function + if err != what { + t.Errorf("invalid error. got=%+v, want=%+v", err, what) + } +} + +// TestErrGroupWithContextDoesRespectParentContext checks the errgroup operations +// are cancellable by parent context +func TestErrGroupWithContextDoesRespectParentContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + eg, _ := WithContext(ctx) + + s1 := make(chan struct{}) + s2 := make(chan struct{}) + eg.Go(func() error { + s1 <- struct{}{} + <-s2 + return fmt.Errorf("func generated error") + }) + + // We have no set limit to errgroup so + // shall be able to start function via TryGo + if ok := eg.TryGo(func() error { return nil }); !ok { + t.Errorf("Expected TryGo to be able start function!!!") + } + + <-s1 // wait for function to start + cancel() // abort parent context + + eg.Go(func() error { + t.Errorf("The parent context was already cancelled and this function shall not be called!!!") + return nil + }) + + s2 <- struct{}{} // signal the func in regular errgroup to fail + err := eg.Wait() // wait errgroup complete and read error + + // The error shall be one returned by the function + // as regular errgroup.WithContext does not respect parent context + if err != context.Canceled { + t.Errorf("expected a context.Canceled error, got=%+v", err) + } +} + +// TestErrGroupFallback tests fallback logic to be compatible with x/sync/errgroup +func TestErrGroupFallback(t *testing.T) { + eg := Group{} + eg.SetLimit(2) + + ch1 := make(chan error) + eg.Go(func() error { return <-ch1 }) + + ch2 := make(chan error) + ok := eg.TryGo(func() error { return <-ch2 }) + if !ok { + t.Errorf("Expected errgroup.TryGo to success!!!") + } + + // The limit set to 2, so 3rd function shall not be possible to call + ok = eg.TryGo(func() error { + t.Errorf("This function is unexpected to be called!!!") + return nil + }) + if ok { + t.Errorf("Expected errgroup.TryGo to fail!!!") + } + + ch1 <- nil + ch2 <- nil + err := eg.Wait() + + if err != nil { + t.Errorf("expected a nil error, got=%+v", err) + } +} diff --git a/msgio.go b/msgio.go index 3f5e492..d648674 100644 --- a/msgio.go +++ b/msgio.go @@ -9,6 +9,7 @@ import ( "io" "sync" + errgrp "github.com/go-zeromq/zmq4/internal/errgroup" "golang.org/x/sync/errgroup" ) @@ -63,11 +64,11 @@ func (q *qreader) Close() error { } func (q *qreader) addConn(r *Conn) { - go q.listen(q.ctx, r) q.mu.Lock() q.sem.enable() q.rs = append(q.rs, r) q.mu.Unlock() + go q.listen(q.ctx, r) } func (q *qreader) rmConn(r *Conn) { @@ -167,7 +168,7 @@ func (mw *mwriter) rmConn(w *Conn) { func (w *mwriter) write(ctx context.Context, msg Msg) error { w.sem.lock(ctx) - grp, _ := errgroup.WithContext(ctx) + grp, _ := errgrp.WithContext(ctx) w.mu.Lock() for i := range w.ws { ww := w.ws[i] diff --git a/options.go b/options.go index d85b7c5..416006a 100644 --- a/options.go +++ b/options.go @@ -44,6 +44,13 @@ func WithDialerTimeout(timeout time.Duration) Option { } } +// WithTimeout sets the timeout value for socket operations +func WithTimeout(timeout time.Duration) Option { + return func(s *socket) { + s.timeout = timeout + } +} + // WithLogger sets a dedicated log.Logger for the socket. func WithLogger(msg *log.Logger) Option { return func(s *socket) { diff --git a/pub.go b/pub.go index 8e6947e..a65ac4d 100644 --- a/pub.go +++ b/pub.go @@ -39,7 +39,7 @@ func (pub *pubSocket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (pub *pubSocket) Send(msg Msg) error { - ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout()) + ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout()) defer cancel() return pub.sck.w.write(ctx, msg) } @@ -49,7 +49,7 @@ func (pub *pubSocket) Send(msg Msg) error { // The message will be sent as a multipart message. func (pub *pubSocket) SendMulti(msg Msg) error { msg.multipart = true - ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout()) + ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout()) defer cancel() return pub.sck.w.write(ctx, msg) } @@ -149,11 +149,11 @@ func (q *pubQReader) Close() error { } func (q *pubQReader) addConn(r *Conn) { - go q.listen(q.ctx, r) q.mu.Lock() q.sem.enable() q.rs = append(q.rs, r) q.mu.Unlock() + go q.listen(q.ctx, r) } func (q *pubQReader) rmConn(r *Conn) { diff --git a/rep.go b/rep.go index b578bad..d9a1a78 100644 --- a/rep.go +++ b/rep.go @@ -34,7 +34,7 @@ func (rep *repSocket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (rep *repSocket) Send(msg Msg) error { - ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout()) + ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout()) defer cancel() return rep.sck.w.write(ctx, msg) } @@ -44,7 +44,7 @@ func (rep *repSocket) Send(msg Msg) error { // The message will be sent as a multipart message. func (rep *repSocket) SendMulti(msg Msg) error { msg.multipart = true - ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout()) + ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout()) defer cancel() return rep.sck.w.write(ctx, msg) } diff --git a/rep_test.go b/rep_test.go index 9d5dba5..e449f68 100644 --- a/rep_test.go +++ b/rep_test.go @@ -134,6 +134,7 @@ func TestCancellation(t *testing.T) { defer wg.Done() repCtx, cancel := context.WithCancel(context.Background()) + defer cancel() rep := zmq4.NewRep(repCtx) defer rep.Close() diff --git a/req.go b/req.go index d0de956..2b99aed 100644 --- a/req.go +++ b/req.go @@ -35,7 +35,7 @@ func (req *reqSocket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (req *reqSocket) Send(msg Msg) error { - ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout()) + ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout()) defer cancel() return req.sck.w.write(ctx, msg) } @@ -45,7 +45,7 @@ func (req *reqSocket) Send(msg Msg) error { // The message will be sent as a multipart message. func (req *reqSocket) SendMulti(msg Msg) error { msg.multipart = true - ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout()) + ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout()) defer cancel() return req.sck.w.write(ctx, msg) } diff --git a/router.go b/router.go index d58de9a..aebb14a 100644 --- a/router.go +++ b/router.go @@ -35,7 +35,7 @@ func (router *routerSocket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (router *routerSocket) Send(msg Msg) error { - ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.timeout()) + ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.Timeout()) defer cancel() return router.sck.w.write(ctx, msg) } @@ -119,11 +119,11 @@ func (q *routerQReader) Close() error { } func (q *routerQReader) addConn(r *Conn) { - go q.listen(q.ctx, r) q.mu.Lock() q.sem.enable() q.rs = append(q.rs, r) q.mu.Unlock() + go q.listen(q.ctx, r) } func (q *routerQReader) rmConn(r *Conn) { diff --git a/socket.go b/socket.go index d322e9f..ca198a6 100644 --- a/socket.go +++ b/socket.go @@ -40,6 +40,7 @@ type socket struct { log *log.Logger subTopics func() []string autoReconnect bool + timeout time.Duration mu sync.RWMutex conns []*Conn // ZMTP connections @@ -67,6 +68,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket { typ: sockType, retry: defaultRetry, maxRetries: defaultMaxRetries, + timeout: defaultTimeout, sec: nullSecurity{}, conns: nil, r: newQReader(ctx), @@ -147,7 +149,7 @@ func (sck *socket) Close() error { // Send puts the message on the outbound send queue. // Send blocks until the message can be queued or the send deadline expires. func (sck *socket) Send(msg Msg) error { - ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout()) + ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout()) defer cancel() return sck.w.write(ctx, msg) } @@ -157,7 +159,7 @@ func (sck *socket) Send(msg Msg) error { // The message will be sent as a multipart message. func (sck *socket) SendMulti(msg Msg) error { msg.multipart = true - ctx, cancel := context.WithTimeout(sck.ctx, sck.timeout()) + ctx, cancel := context.WithTimeout(sck.ctx, sck.Timeout()) defer cancel() return sck.w.write(ctx, msg) } @@ -365,9 +367,8 @@ func (sck *socket) SetOption(name string, value interface{}) error { return nil } -func (sck *socket) timeout() time.Duration { - // FIXME(sbinet): extract from options - return defaultTimeout +func (sck *socket) Timeout() time.Duration { + return sck.timeout } func (sck *socket) connReaper() { diff --git a/zmq4_timeout_test.go b/zmq4_timeout_test.go new file mode 100644 index 0000000..725f1a8 --- /dev/null +++ b/zmq4_timeout_test.go @@ -0,0 +1,49 @@ +// Copyright 2023 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4 + +import ( + "context" + "testing" + "time" +) + +func TestPushTimeout(t *testing.T) { + ep := "ipc://@push_timeout_test" + push := NewPush(context.Background(), WithTimeout(1*time.Second)) + defer push.Close() + if err := push.Listen(ep); err != nil { + t.FailNow() + } + + pull := NewPull(context.Background()) + defer pull.Close() + if err := pull.Dial(ep); err != nil { + t.FailNow() + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for { + select { + case <-ctx.Done(): + // The ctx limits overall time of execution + // If it gets canceled, that meain tests failed + // as write to socket did not genereate timeout error + t.Fatalf("test failed before being able to generate timeout error: %+v", ctx.Err()) + default: + } + + err := push.Send(NewMsgString("test string")) + if err == nil { + continue + } + if err != context.DeadlineExceeded { + t.Fatalf("expected a context.DeadlineExceeded error, got=%+v", err) + } + break + } + +}