Skip to content

Commit

Permalink
zmq4: add timeout support on send
Browse files Browse the repository at this point in the history
Add internal/errgroup package to support cancellable error groups.

Fixes #147.

Authored-by: Sergey Egorov <[email protected]>
  • Loading branch information
egorse committed Dec 15, 2023
1 parent e16dc3e commit 16ca7c0
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 16 deletions.
107 changes: 107 additions & 0 deletions internal/errgroup/errgroup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package errgroup is bit more advanced than golang.org/x/sync/errgroup.
// Major difference is that when error group is created with WithContext
// the parent context would implicitly cancel all functions called by Go method.
package errgroup

import (
"context"

"golang.org/x/sync/errgroup"
)

// The Group is superior errgroup.Group which aborts whole group
// execution when parent context is cancelled
type Group struct {
grp *errgroup.Group
ctx context.Context
}

// WithContext creates Group and store inside parent context
// so the Go method would respect parent context cancellation
func WithContext(ctx context.Context) (*Group, context.Context) {
grp, child_ctx := errgroup.WithContext(ctx)
return &Group{grp: grp, ctx: ctx}, child_ctx
}

// Go runs the provided f function in a dedicated goroutine and waits for its
// completion or for the parent context cancellation.
func (g *Group) Go(f func() error) {
g.getErrGroup().Go(g.wrap(f))
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
// If the error group was created via WithContext then the Wait returns error
// of cancelled parent context prior any functions calls complete.
func (g *Group) Wait() error {
return g.getErrGroup().Wait()
}

// SetLimit limits the number of active goroutines in this group to at most n.
// A negative value indicates no limit.
//
// Any subsequent call to the Go method will block until it can add an active
// goroutine without exceeding the configured limit.
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
g.getErrGroup().SetLimit(n)
}

// TryGo calls the given function in a new goroutine only if the number of
// active goroutines in the group is currently below the configured limit.
//
// The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool {
return g.getErrGroup().TryGo(g.wrap(f))
}

func (g *Group) wrap(f func() error) func() error {
if g.ctx == nil {
return f
}

return func() error {
// If parent context is canceled,
// just return its error and do not call func f
select {
case <-g.ctx.Done():
return g.ctx.Err()
default:
}

// Create return channel and call func f
// Buffered channel is used as the following select
// may be exiting by context cancellation
// and in such case the write to channel can be block
// and cause the go routine leak
ch := make(chan error, 1)
go func() {
ch <- f()
}()

// Wait func f complete or
// parent context to be cancelled,
select {
case err := <-ch:
return err
case <-g.ctx.Done():
return g.ctx.Err()
}
}
}

// The getErrGroup returns actual x/sync/errgroup.Group.
// If the group is not allocated it would implicitly allocate it.
// Thats allows the internal/errgroup.Group be fully
// compatible to x/sync/errgroup.Group
func (g *Group) getErrGroup() *errgroup.Group {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
return g.grp
}
124 changes: 124 additions & 0 deletions internal/errgroup/errgroup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package errgroup

import (
"context"
"fmt"
"testing"

"golang.org/x/sync/errgroup"
)

// TestRegularErrGroupDoesNotRespectParentContext checks regular errgroup behavior
// where errgroup.WithContext does not respect the parent context
func TestRegularErrGroupDoesNotRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := errgroup.WithContext(ctx)

what := fmt.Errorf("func generated error")
ch := make(chan error)
eg.Go(func() error { return <-ch })

cancel() // abort parent context
ch <- what // signal the func in regular errgroup to fail
err := eg.Wait()

// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != what {
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
}
}

// TestErrGroupWithContextCanCallFunctions checks the errgroup operations
// are fine working and errgroup called function can return error
func TestErrGroupWithContextCanCallFunctions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
eg, _ := WithContext(ctx)

what := fmt.Errorf("func generated error")
ch := make(chan error)
eg.Go(func() error { return <-ch })

ch <- what // signal the func in errgroup to fail
err := eg.Wait() // wait errgroup complete and read error

// The error shall be one returned by the function
if err != what {
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
}
}

// TestErrGroupWithContextDoesRespectParentContext checks the errgroup operations
// are cancellable by parent context
func TestErrGroupWithContextDoesRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := WithContext(ctx)

s1 := make(chan struct{})
s2 := make(chan struct{})
eg.Go(func() error {
s1 <- struct{}{}
<-s2
return fmt.Errorf("func generated error")
})

// We have no set limit to errgroup so
// shall be able to start function via TryGo
if ok := eg.TryGo(func() error { return nil }); !ok {
t.Errorf("Expected TryGo to be able start function!!!")
}

<-s1 // wait for function to start
cancel() // abort parent context

eg.Go(func() error {
t.Errorf("The parent context was already cancelled and this function shall not be called!!!")
return nil
})

s2 <- struct{}{} // signal the func in regular errgroup to fail
err := eg.Wait() // wait errgroup complete and read error

// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != context.Canceled {
t.Errorf("expected a context.Canceled error, got=%+v", err)
}
}

// TestErrGroupFallback tests fallback logic to be compatible with x/sync/errgroup
func TestErrGroupFallback(t *testing.T) {
eg := Group{}
eg.SetLimit(2)

ch1 := make(chan error)
eg.Go(func() error { return <-ch1 })

ch2 := make(chan error)
ok := eg.TryGo(func() error { return <-ch2 })
if !ok {
t.Errorf("Expected errgroup.TryGo to success!!!")
}

// The limit set to 2, so 3rd function shall not be possible to call
ok = eg.TryGo(func() error {
t.Errorf("This function is unexpected to be called!!!")
return nil
})
if ok {
t.Errorf("Expected errgroup.TryGo to fail!!!")
}

ch1 <- nil
ch2 <- nil
err := eg.Wait()

if err != nil {
t.Errorf("expected a nil error, got=%+v", err)
}
}
5 changes: 3 additions & 2 deletions msgio.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"sync"

errgrp "github.com/go-zeromq/zmq4/internal/errgroup"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -63,11 +64,11 @@ func (q *qreader) Close() error {
}

func (q *qreader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}

func (q *qreader) rmConn(r *Conn) {
Expand Down Expand Up @@ -167,7 +168,7 @@ func (mw *mwriter) rmConn(w *Conn) {

func (w *mwriter) write(ctx context.Context, msg Msg) error {
w.sem.lock(ctx)
grp, _ := errgroup.WithContext(ctx)
grp, _ := errgrp.WithContext(ctx)
w.mu.Lock()
for i := range w.ws {
ww := w.ws[i]
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ func WithDialerTimeout(timeout time.Duration) Option {
}
}

// WithTimeout sets the timeout value for socket operations
func WithTimeout(timeout time.Duration) Option {
return func(s *socket) {
s.timeout = timeout
}
}

// WithLogger sets a dedicated log.Logger for the socket.
func WithLogger(msg *log.Logger) Option {
return func(s *socket) {
Expand Down
6 changes: 3 additions & 3 deletions pub.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (pub *pubSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (pub *pubSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
defer cancel()
return pub.sck.w.write(ctx, msg)
}
Expand All @@ -49,7 +49,7 @@ func (pub *pubSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (pub *pubSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
defer cancel()
return pub.sck.w.write(ctx, msg)
}
Expand Down Expand Up @@ -149,11 +149,11 @@ func (q *pubQReader) Close() error {
}

func (q *pubQReader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}

func (q *pubQReader) rmConn(r *Conn) {
Expand Down
4 changes: 2 additions & 2 deletions rep.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (rep *repSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (rep *repSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
defer cancel()
return rep.sck.w.write(ctx, msg)
}
Expand All @@ -44,7 +44,7 @@ func (rep *repSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (rep *repSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
defer cancel()
return rep.sck.w.write(ctx, msg)
}
Expand Down
1 change: 1 addition & 0 deletions rep_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func TestCancellation(t *testing.T) {

defer wg.Done()
repCtx, cancel := context.WithCancel(context.Background())
defer cancel()
rep := zmq4.NewRep(repCtx)
defer rep.Close()

Expand Down
4 changes: 2 additions & 2 deletions req.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (req *reqSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (req *reqSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
defer cancel()
return req.sck.w.write(ctx, msg)
}
Expand All @@ -45,7 +45,7 @@ func (req *reqSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (req *reqSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
defer cancel()
return req.sck.w.write(ctx, msg)
}
Expand Down
4 changes: 2 additions & 2 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (router *routerSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (router *routerSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.timeout())
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.Timeout())
defer cancel()
return router.sck.w.write(ctx, msg)
}
Expand Down Expand Up @@ -119,11 +119,11 @@ func (q *routerQReader) Close() error {
}

func (q *routerQReader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}

func (q *routerQReader) rmConn(r *Conn) {
Expand Down
Loading

0 comments on commit 16ca7c0

Please sign in to comment.