Skip to content

Commit

Permalink
fix(reader): Separate initial connect timeout context so that context…
Browse files Browse the repository at this point in the history
… does not expire too early
  • Loading branch information
bow committed Jan 26, 2024
1 parent 3724a24 commit 524dec7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
15 changes: 7 additions & 8 deletions cmd/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package cmd

import (
"context"
"fmt"
"net"
"time"
Expand Down Expand Up @@ -39,10 +38,11 @@ func newReaderCommand() *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error {

var (
err error
connectAddr net.Addr
ctx = cmd.Context()
dialOpts = []grpc.DialOption{
err error
connectAddr net.Addr
connectTimeout time.Duration
ctx = cmd.Context()
dialOpts = []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
}
addr = resolveAddr(v, addrKey, connectKey, defaultConnectAddr, defaultStartAddr)
Expand All @@ -54,9 +54,7 @@ func newReaderCommand() *cobra.Command {
return err
}
dialOpts = append(dialOpts, grpc.WithBlock())
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, v.GetDuration(connectTimeoutKey))
defer cancel()
connectTimeout = v.GetDuration(connectTimeoutKey)

} else {
server, ierr := makeServer(cmd, v, addr)
Expand All @@ -74,6 +72,7 @@ func newReaderCommand() *cobra.Command {

rdr, err := reader.NewBuilder(cmd.Context()).
Context(ctx).
ConnectTimeout(connectTimeout).
Address(connectAddr.String()).
DialOpts(dialOpts...).
Build()
Expand Down
24 changes: 18 additions & 6 deletions internal/reader/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,10 @@ type Builder struct {
scr tcell.Screen

// rpcBackend args.
addr string
dopts []grpc.DialOption
callTimeout time.Duration
addr string
dopts []grpc.DialOption
callTimeout time.Duration
connectTimeout time.Duration

// For testing.
be bknd.Backend
Expand All @@ -200,6 +201,11 @@ func (b *Builder) Address(addr string) *Builder {
return b
}

func (b *Builder) ConnectTimeout(timeout time.Duration) *Builder {
b.connectTimeout = timeout
return b
}

func (b *Builder) DialOpts(dialOpts ...grpc.DialOption) *Builder {
b.dopts = dialOpts
return b
Expand Down Expand Up @@ -247,13 +253,19 @@ func (b *Builder) Build() (*Reader, error) {
}

var (
be bknd.Backend
err error
be bknd.Backend
err error
connectCtx = b.ctx
cancel context.CancelFunc
)
if b.be != nil {
be = b.be
} else {
be, err = bknd.NewRPC(b.ctx, b.addr, b.dopts...)
if b.connectTimeout > 0 {
connectCtx, cancel = context.WithTimeout(b.ctx, b.connectTimeout)
defer cancel()
}
be, err = bknd.NewRPC(connectCtx, b.addr, b.dopts...)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 524dec7

Please sign in to comment.