Skip to content

Commit

Permalink
Add timeout support on send
Browse files Browse the repository at this point in the history
Add internal/errorgrp package to support cancellable error groups
Add tests for push/pull timeout
  • Loading branch information
Sergey Egorov committed Dec 11, 2023
1 parent e16dc3e commit 147519a
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 5 deletions.
61 changes: 61 additions & 0 deletions internal/errorgrp/errorgrp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Package errorgrp is bit more advanced than errgroup
// Major difference is that when error group is created with WithContext2
// the parent context would implicitly cancel all functions called by Go method.
//
// The name is selected so you can mix regular errgroup and errorgrp in same file.
package errorgrp

import (
"context"

"golang.org/x/sync/errgroup"
)

// The Group2 is superior errgroup.Group which aborts whole group
// execution when parent context is cancelled
type Group2 struct {
grp *errgroup.Group
ctx context.Context
}

// WithContext2 creates Group2 and store inside parent context
// so the Go method would respect parent context cancellation
func WithContext2(ctx context.Context) (*Group2, context.Context) {
grp, child_ctx := errgroup.WithContext(ctx)
return &Group2{grp: grp, ctx: ctx}, child_ctx
}

// Go function would wait for parent context to be cancelled,
// or func f to be complete complete
func (g *Group2) Go(f func() error) {
g.grp.Go(func() error {
// If parent context is canceled,
// just return its error and do not call func f
select {
case <-g.ctx.Done():
return g.ctx.Err()
default:
}

// Create return channel
// and call func f
ch := make(chan error, 1)
go func() {
ch <- f()
}()

// Wait func f complete or
// parent context to be cancelled,
select {
case err := <-ch:
return err
case <-g.ctx.Done():
return g.ctx.Err()
}
})
}

// Wait is direct call to errgroup.Wait
func (g *Group2) Wait() error {
return g.grp.Wait()
}
61 changes: 61 additions & 0 deletions internal/errorgrp/errorgrp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package errorgrp

import (
"context"
"fmt"
"testing"

"golang.org/x/sync/errgroup"
)

// TestErrGroupDoesNotRespectParentContext check regulare errgroup behavior
// where errgroup.WithContext does not respects the parent context
func TestErrGroupDoesNotRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := errgroup.WithContext(ctx)

er := fmt.Errorf("func generated error")
s := make(chan struct{}, 1)
eg.Go(func() error {
<-s
return er
})

// Abort context
cancel()
// Signal the func in regular errgroup to fail
s <- struct{}{}
// Wait regular errgroup complete and read error
err := eg.Wait()

// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != er {
t.Fail()
}
}

func TestErrorGrpWithContext2DoesRespectsParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := WithContext2(ctx)

er := fmt.Errorf("func generated error")
s := make(chan struct{}, 1)
eg.Go(func() error {
<-s
return er
})

// Abort context
cancel()
// Signal the func in regular errgroup to fail
s <- struct{}{}
// Wait regular errgroup complete and read error
err := eg.Wait()

// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != context.Canceled {
t.Fail()
}
}
3 changes: 2 additions & 1 deletion msgio.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"sync"

"github.com/go-zeromq/zmq4/internal/errorgrp"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -167,7 +168,7 @@ func (mw *mwriter) rmConn(w *Conn) {

func (w *mwriter) write(ctx context.Context, msg Msg) error {
w.sem.lock(ctx)
grp, _ := errgroup.WithContext(ctx)
grp, _ := errorgrp.WithContext2(ctx)
w.mu.Lock()
for i := range w.ws {
ww := w.ws[i]
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ func WithDialerTimeout(timeout time.Duration) Option {
}
}

// WithTimeout sets the timeout value for socket operations
func WithTimeout(timeout time.Duration) Option {
return func(s *socket) {
s.Timeout = timeout
}
}

// WithLogger sets a dedicated log.Logger for the socket.
func WithLogger(msg *log.Logger) Option {
return func(s *socket) {
Expand Down
4 changes: 2 additions & 2 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"net"
"sync"

"golang.org/x/sync/errgroup"
"github.com/go-zeromq/zmq4/internal/errorgrp"
)

// NewRouter returns a new ROUTER ZeroMQ socket.
Expand Down Expand Up @@ -225,7 +225,7 @@ func (mw *routerMWriter) rmConn(w *Conn) {

func (w *routerMWriter) write(ctx context.Context, msg Msg) error {
w.sem.lock(ctx)
grp, _ := errgroup.WithContext(ctx)
grp, _ := errorgrp.WithContext2(ctx)
w.mu.Lock()
id := msg.Frames[0]
dmsg := NewMsgFrom(msg.Frames[1:]...)
Expand Down
5 changes: 3 additions & 2 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type socket struct {
log *log.Logger
subTopics func() []string
autoReconnect bool
Timeout time.Duration

mu sync.RWMutex
conns []*Conn // ZMTP connections
Expand Down Expand Up @@ -67,6 +68,7 @@ func newDefaultSocket(ctx context.Context, sockType SocketType) *socket {
typ: sockType,
retry: defaultRetry,
maxRetries: defaultMaxRetries,
Timeout: defaultTimeout,
sec: nullSecurity{},
conns: nil,
r: newQReader(ctx),
Expand Down Expand Up @@ -366,8 +368,7 @@ func (sck *socket) SetOption(name string, value interface{}) error {
}

func (sck *socket) timeout() time.Duration {
// FIXME(sbinet): extract from options
return defaultTimeout
return sck.Timeout
}

func (sck *socket) connReaper() {
Expand Down
45 changes: 45 additions & 0 deletions zmq4_timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package zmq4

import (
"context"
"testing"
"time"
)

func TestPushTimeout(t *testing.T) {
ep := "ipc://@push_timeout_test"
push := NewPush(context.Background(), WithTimeout(1*time.Second))
defer push.Close()
if err := push.Listen(ep); err != nil {
t.FailNow()
}

pull := NewPull(context.Background())
defer pull.Close()
if err := pull.Dial(ep); err != nil {
t.FailNow()
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
for {
select {
case <-ctx.Done():
// The ctx limits overall time of execution
// If it gets canceled, that meains tests failed
// as write to socket did not genereate timeout error
t.FailNow()
default:
}

err := push.Send(NewMsgString("test string"))
if err == nil {
continue
}
if err != context.DeadlineExceeded {
t.FailNow()
}
break
}

}

0 comments on commit 147519a

Please sign in to comment.