Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout support on send #148

Merged
merged 8 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions internal/errgroup/errgroup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package errgroup is bit more advanced than golang.org/x/sync/errgroup.
// Major difference is that when error group is created with WithContext
// the parent context would implicitly cancel all functions called by Go method.
package errgroup

import (
"context"

"golang.org/x/sync/errgroup"
)

// The Group is superior errgroup.Group which aborts whole group
// execution when parent context is cancelled
type Group struct {
grp *errgroup.Group
ctx context.Context
}

// WithContext creates Group and store inside parent context
// so the Go method would respect parent context cancellation
func WithContext(ctx context.Context) (*Group, context.Context) {
grp, child_ctx := errgroup.WithContext(ctx)
return &Group{grp: grp, ctx: ctx}, child_ctx
sbinet marked this conversation as resolved.
Show resolved Hide resolved
}

// Go runs the provided f function in a dedicated goroutine and waits for its
// completion or for the parent context cancellation.
func (g *Group) Go(f func() error) {
g.getErrGroup().Go(g.wrap(f))
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
// If the error group was created via WithContext then the Wait returns error
// of cancelled parent context prior any functions calls complete.
func (g *Group) Wait() error {
return g.getErrGroup().Wait()
}

// SetLimit limits the number of active goroutines in this group to at most n.
// A negative value indicates no limit.
//
// Any subsequent call to the Go method will block until it can add an active
// goroutine without exceeding the configured limit.
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
g.getErrGroup().SetLimit(n)
}

// TryGo calls the given function in a new goroutine only if the number of
// active goroutines in the group is currently below the configured limit.
//
// The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool {
return g.getErrGroup().TryGo(g.wrap(f))
}

func (g *Group) wrap(f func() error) func() error {
if g.ctx == nil {
return f
}

return func() error {
// If parent context is canceled,
// just return its error and do not call func f
select {
case <-g.ctx.Done():
return g.ctx.Err()
default:
}

// Create return channel and call func f
// Buffered channel is used as the following select
// may be exiting by context cancellation
// and in such case the write to channel can be block
// and cause the go routine leak
ch := make(chan error, 1)
go func() {
ch <- f()
}()

// Wait func f complete or
// parent context to be cancelled,
select {
case err := <-ch:
return err
case <-g.ctx.Done():
return g.ctx.Err()
}
}
}

// The getErrGroup returns actual x/sync/errgroup.Group.
// If the group is not allocated it would implicitly allocate it.
// Thats allows the internal/errgroup.Group be fully
// compatible to x/sync/errgroup.Group
func (g *Group) getErrGroup() *errgroup.Group {
if g.grp == nil {
g.grp = &errgroup.Group{}
}
return g.grp
}
124 changes: 124 additions & 0 deletions internal/errgroup/errgroup_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2023 The go-zeromq Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package errgroup

import (
"context"
"fmt"
"testing"

"golang.org/x/sync/errgroup"
)

// TestRegularErrGroupDoesNotRespectParentContext checks regular errgroup behavior
// where errgroup.WithContext does not respect the parent context
func TestRegularErrGroupDoesNotRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := errgroup.WithContext(ctx)

what := fmt.Errorf("func generated error")
ch := make(chan error)
eg.Go(func() error { return <-ch })

cancel() // abort parent context
ch <- what // signal the func in regular errgroup to fail
err := eg.Wait()

// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != what {
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
}
}

// TestErrGroupWithContextCanCallFunctions checks the errgroup operations
// are fine working and errgroup called function can return error
func TestErrGroupWithContextCanCallFunctions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
eg, _ := WithContext(ctx)

what := fmt.Errorf("func generated error")
ch := make(chan error)
eg.Go(func() error { return <-ch })

ch <- what // signal the func in errgroup to fail
err := eg.Wait() // wait errgroup complete and read error

// The error shall be one returned by the function
if err != what {
t.Errorf("invalid error. got=%+v, want=%+v", err, what)
}
}

// TestErrGroupWithContextDoesRespectParentContext checks the errgroup operations
// are cancellable by parent context
func TestErrGroupWithContextDoesRespectParentContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
eg, _ := WithContext(ctx)

s1 := make(chan struct{})
s2 := make(chan struct{})
eg.Go(func() error {
s1 <- struct{}{}
<-s2
return fmt.Errorf("func generated error")
})

// We have no set limit to errgroup so
// shall be able to start function via TryGo
if ok := eg.TryGo(func() error { return nil }); !ok {
t.Errorf("Expected TryGo to be able start function!!!")
}

<-s1 // wait for function to start
cancel() // abort parent context

eg.Go(func() error {
t.Errorf("The parent context was already cancelled and this function shall not be called!!!")
return nil
})

s2 <- struct{}{} // signal the func in regular errgroup to fail
err := eg.Wait() // wait errgroup complete and read error

// The error shall be one returned by the function
// as regular errgroup.WithContext does not respect parent context
if err != context.Canceled {
t.Errorf("expected a context.Canceled error, got=%+v", err)
}
}

// TestErrGroupFallback tests fallback logic to be compatible with x/sync/errgroup
func TestErrGroupFallback(t *testing.T) {
eg := Group{}
eg.SetLimit(2)

ch1 := make(chan error)
eg.Go(func() error { return <-ch1 })

ch2 := make(chan error)
ok := eg.TryGo(func() error { return <-ch2 })
if !ok {
t.Errorf("Expected errgroup.TryGo to success!!!")
}

// The limit set to 2, so 3rd function shall not be possible to call
ok = eg.TryGo(func() error {
t.Errorf("This function is unexpected to be called!!!")
return nil
})
if ok {
t.Errorf("Expected errgroup.TryGo to fail!!!")
}

ch1 <- nil
ch2 <- nil
err := eg.Wait()

if err != nil {
t.Errorf("expected a nil error, got=%+v", err)
}
}
5 changes: 3 additions & 2 deletions msgio.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"sync"

errgrp "github.com/go-zeromq/zmq4/internal/errgroup"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -63,11 +64,11 @@ func (q *qreader) Close() error {
}

func (q *qreader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}

func (q *qreader) rmConn(r *Conn) {
Expand Down Expand Up @@ -167,7 +168,7 @@ func (mw *mwriter) rmConn(w *Conn) {

func (w *mwriter) write(ctx context.Context, msg Msg) error {
w.sem.lock(ctx)
grp, _ := errgroup.WithContext(ctx)
grp, _ := errgrp.WithContext(ctx)
w.mu.Lock()
for i := range w.ws {
ww := w.ws[i]
Expand Down
7 changes: 7 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ func WithDialerTimeout(timeout time.Duration) Option {
}
}

// WithTimeout sets the timeout value for socket operations
func WithTimeout(timeout time.Duration) Option {
return func(s *socket) {
s.timeout = timeout
}
}

// WithLogger sets a dedicated log.Logger for the socket.
func WithLogger(msg *log.Logger) Option {
return func(s *socket) {
Expand Down
6 changes: 3 additions & 3 deletions pub.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (pub *pubSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (pub *pubSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
defer cancel()
return pub.sck.w.write(ctx, msg)
}
Expand All @@ -49,7 +49,7 @@ func (pub *pubSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (pub *pubSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.timeout())
ctx, cancel := context.WithTimeout(pub.sck.ctx, pub.sck.Timeout())
defer cancel()
return pub.sck.w.write(ctx, msg)
}
Expand Down Expand Up @@ -149,11 +149,11 @@ func (q *pubQReader) Close() error {
}

func (q *pubQReader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}

func (q *pubQReader) rmConn(r *Conn) {
Expand Down
4 changes: 2 additions & 2 deletions rep.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (rep *repSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (rep *repSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
defer cancel()
return rep.sck.w.write(ctx, msg)
}
Expand All @@ -44,7 +44,7 @@ func (rep *repSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (rep *repSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.timeout())
ctx, cancel := context.WithTimeout(rep.sck.ctx, rep.sck.Timeout())
defer cancel()
return rep.sck.w.write(ctx, msg)
}
Expand Down
1 change: 1 addition & 0 deletions rep_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func TestCancellation(t *testing.T) {

defer wg.Done()
repCtx, cancel := context.WithCancel(context.Background())
defer cancel()
rep := zmq4.NewRep(repCtx)
defer rep.Close()

Expand Down
4 changes: 2 additions & 2 deletions req.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (req *reqSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (req *reqSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
defer cancel()
return req.sck.w.write(ctx, msg)
}
Expand All @@ -45,7 +45,7 @@ func (req *reqSocket) Send(msg Msg) error {
// The message will be sent as a multipart message.
func (req *reqSocket) SendMulti(msg Msg) error {
msg.multipart = true
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.timeout())
ctx, cancel := context.WithTimeout(req.sck.ctx, req.sck.Timeout())
defer cancel()
return req.sck.w.write(ctx, msg)
}
Expand Down
4 changes: 2 additions & 2 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (router *routerSocket) Close() error {
// Send puts the message on the outbound send queue.
// Send blocks until the message can be queued or the send deadline expires.
func (router *routerSocket) Send(msg Msg) error {
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.timeout())
ctx, cancel := context.WithTimeout(router.sck.ctx, router.sck.Timeout())
defer cancel()
return router.sck.w.write(ctx, msg)
}
Expand Down Expand Up @@ -119,11 +119,11 @@ func (q *routerQReader) Close() error {
}

func (q *routerQReader) addConn(r *Conn) {
go q.listen(q.ctx, r)
q.mu.Lock()
q.sem.enable()
q.rs = append(q.rs, r)
q.mu.Unlock()
go q.listen(q.ctx, r)
}

func (q *routerQReader) rmConn(r *Conn) {
Expand Down
Loading
Loading