Skip to content

Commit

Permalink
enables grpc.Invoke timeouts in the unary client
Browse files Browse the repository at this point in the history
  • Loading branch information
drewwells committed Apr 3, 2020
1 parent 747bac5 commit 10b64ec
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 63 deletions.
26 changes: 13 additions & 13 deletions pep/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,27 +218,27 @@ const (
)

type options struct {
addresses []string
balancer int
tracer ot.Tracer
maxStreams int
ctx context.Context
addresses []string
balancer int
tracer ot.Tracer
maxStreams int
ctx context.Context
connStateCb ConnectionStateNotificationCallback
autoRequestSize bool
maxRequestSize uint32
noPool bool
cache bool
cacheTTL time.Duration
cacheMaxSize int
// ignored by Unary client
connTimeout time.Duration
connStateCb ConnectionStateNotificationCallback
autoRequestSize bool
maxRequestSize uint32
noPool bool
cache bool
cacheTTL time.Duration
cacheMaxSize int
onCacheHitHandler OnCacheHitHandler
clientUnaryInterceptors []grpc.UnaryClientInterceptor
}

// NewClient creates client instance using given options.
func NewClient(opts ...Option) Client {
o := options{
connTimeout: -1,
maxRequestSize: 10240,
}
for _, opt := range opts {
Expand Down
48 changes: 24 additions & 24 deletions pep/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ import (

func TestNewClient(t *testing.T) {
c := NewClient()
if _, ok := c.(*unaryClient); !ok {
t.Errorf("Expected *unaryClient from NewClient got %#v", c)
if _, ok := c.(*UnaryClient); !ok {
t.Errorf("Expected *UnaryClient from NewClient got %#v", c)
}
}

func TestNewBalancedClient(t *testing.T) {
c := NewClient(WithRoundRobinBalancer("127.0.0.1:1000", "127.0.0.1:1001"))
if uc, ok := c.(*unaryClient); ok {
if uc, ok := c.(*UnaryClient); ok {
if len(uc.opts.addresses) <= 0 {
t.Errorf("Expected balancer to be set but got nothing")
}
} else {
t.Errorf("Expected *unaryClient from NewClient got %#v", c)
t.Errorf("Expected *UnaryClient from NewClient got %#v", c)
}

c = NewClient(WithHotSpotBalancer("127.0.0.1:1000", "127.0.0.1:1001"), WithStreams(5))
Expand All @@ -50,9 +50,9 @@ func TestNewStreamingClient(t *testing.T) {
func TestNewClientWithTracer(t *testing.T) {
tr := &ot.NoopTracer{}
c := NewClient(WithTracer(tr))
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if uc.opts.tracer != tr {
Expand All @@ -67,9 +67,9 @@ var noOpClientInterceptor grpc.UnaryClientInterceptor = func(ctx context.Context
func TestNewClientWithInterceptor(t *testing.T) {
ci := noOpClientInterceptor
c := NewClient(WithClientUnaryInterceptors(ci))
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if len(uc.opts.clientUnaryInterceptors) == 0 {
Expand All @@ -79,9 +79,9 @@ func TestNewClientWithInterceptor(t *testing.T) {

func TestNewClientWithAutoRequestSize(t *testing.T) {
c := NewClient(WithAutoRequestSize(true))
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if !uc.opts.autoRequestSize {
Expand All @@ -91,9 +91,9 @@ func TestNewClientWithAutoRequestSize(t *testing.T) {

func TestNewClientWithMaxRequestSize(t *testing.T) {
c := NewClient(WithMaxRequestSize(1024))
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if uc.opts.maxRequestSize != 1024 {
Expand All @@ -103,9 +103,9 @@ func TestNewClientWithMaxRequestSize(t *testing.T) {

func TestNewClientWithNoRequestBufferPool(t *testing.T) {
c := NewClient(WithNoRequestBufferPool())
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if uc.pool.b != nil {
Expand All @@ -115,9 +115,9 @@ func TestNewClientWithNoRequestBufferPool(t *testing.T) {

func TestNewClientWithCacheTTL(t *testing.T) {
c := NewClient(WithCacheTTL(5 * time.Second))
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if !uc.opts.cache || uc.opts.cacheTTL != 5*time.Second {
Expand All @@ -127,9 +127,9 @@ func TestNewClientWithCacheTTL(t *testing.T) {

func TestNewClientWithCacheTTLAndMaxSize(t *testing.T) {
c := NewClient(WithCacheTTLAndMaxSize(5*time.Second, 1024))
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}

if !uc.opts.cache || uc.opts.cacheTTL != 5*time.Second || uc.opts.cacheMaxSize != 1024 {
Expand All @@ -140,18 +140,18 @@ func TestNewClientWithCacheTTLAndMaxSize(t *testing.T) {

func TestNewClientWithContext(t *testing.T) {
c := NewClient()
uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}
if uc.opts.ctx != nil {
t.Errorf("Expected default client to have nil context")
}

c = NewClient(WithContext(nil))
uc, ok = c.(*unaryClient)
uc, ok = c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}
if uc.opts.ctx != nil {
t.Errorf("Expected nil context to default to nil context")
Expand All @@ -160,9 +160,9 @@ func TestNewClientWithContext(t *testing.T) {
toCtx, toCancelFn := context.WithTimeout(context.Background(), 1*time.Second)
defer toCancelFn()
c = NewClient(WithContext(toCtx))
uc, ok = c.(*unaryClient)
uc, ok = c.(*UnaryClient)
if !ok {
t.Fatalf("Expected *unaryClient from NewClient got %#v", c)
t.Fatalf("Expected *UnaryClient from NewClient got %#v", c)
}
if uc.opts.ctx != toCtx {
t.Errorf("Expected timeout context")
Expand Down
35 changes: 16 additions & 19 deletions pep/unary_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ import (
"sync"

"github.com/allegro/bigcache"
"github.com/grpc-ecosystem/go-grpc-middleware"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
ot "github.com/opentracing/opentracing-go"
"google.golang.org/grpc"

pb "github.com/infobloxopen/themis/pdp-service"
)

type unaryClient struct {
type UnaryClient struct {
lock *sync.RWMutex
conn *grpc.ClientConn
client *pb.PDPClient
Expand All @@ -26,8 +26,8 @@ type unaryClient struct {
opts options
}

func newUnaryClient(opts options) *unaryClient {
c := &unaryClient{
func newUnaryClient(opts options) *UnaryClient {
c := &UnaryClient{
lock: &sync.RWMutex{},
opts: opts,
}
Expand All @@ -39,7 +39,7 @@ func newUnaryClient(opts options) *unaryClient {
return c
}

func (c *unaryClient) Connect(addr string) error {
func (c *UnaryClient) Connect(addr string) error {
c.lock.Lock()
defer c.lock.Unlock()

Expand Down Expand Up @@ -96,13 +96,6 @@ func (c *unaryClient) Connect(addr string) error {
ctx = context.Background()
}

if c.opts.connTimeout > 0 {
var cancelFn context.CancelFunc
ctx, cancelFn = context.WithTimeout(ctx, c.opts.connTimeout)
defer cancelFn()
}

opts = append(opts, grpc.WithBlock())
conn, err := grpc.DialContext(ctx, addr, opts...)
if err != nil {
return err
Expand All @@ -117,7 +110,7 @@ func (c *unaryClient) Connect(addr string) error {
return nil
}

func (c *unaryClient) Close() {
func (c *UnaryClient) Close() {
c.lock.Lock()
defer c.lock.Unlock()

Expand All @@ -134,7 +127,16 @@ func (c *unaryClient) Close() {
c.client = nil
}

func (c *unaryClient) Validate(in, out interface{}) error {
func (c *UnaryClient) ValidateContext(ctx context.Context, in, out interface{}) error {
return c.validate(ctx, in, out)
}

// Validate is deprecated, use ValidateContext
func (c *UnaryClient) Validate(in, out interface{}) error {
return c.validate(context.Background(), in, out)
}

func (c *UnaryClient) validate(ctx context.Context, in, out interface{}) error {
c.lock.RLock()
uc := c.client
c.lock.RUnlock()
Expand Down Expand Up @@ -181,11 +183,6 @@ func (c *unaryClient) Validate(in, out interface{}) error {
}
}

ctx := c.opts.ctx
if ctx == nil {
ctx = context.Background()
}

if c.opts.connTimeout > 0 {
var cancelFn context.CancelFunc
ctx, cancelFn = context.WithTimeout(ctx, c.opts.connTimeout)
Expand Down
23 changes: 16 additions & 7 deletions pep/unary_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/status"

"github.com/infobloxopen/themis/pdp"
Expand Down Expand Up @@ -88,9 +89,9 @@ func TestUnaryClientValidationWithCache(t *testing.T) {
}
defer c.Close()

uc, ok := c.(*unaryClient)
uc, ok := c.(*UnaryClient)
if !ok {
t.Fatalf("expected *unaryClient but got %#v", c)
t.Fatalf("expected *UnaryClient but got %#v", c)
}
bc := uc.cache
if bc == nil {
Expand Down Expand Up @@ -173,12 +174,20 @@ func startTestPDPServer(p string, s uint16, t *testing.T) *loggedServer {
}

func TestUnaryClientConnectTimeout(t *testing.T) {
c := NewClient(WithConnectionTimeout(1 * time.Second))
c := NewClient().(*UnaryClient)
err := c.Connect("127.0.0.1:5555")
if err == nil {
t.Fatalf("expected DeadlineExceeded error")
} else if err != context.DeadlineExceeded {
t.Fatalf("expected DeadlineExceeded error but got %s", err)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Second)
defer cancel()
c.conn.WaitForStateChange(ctx, connectivity.Idle)
c.conn.WaitForStateChange(ctx, connectivity.Connecting)
if e := connectivity.TransientFailure; e != c.conn.GetState() {
t.Errorf("wanted: %s got: %s", e, c.conn.GetState())
}
if ctx.Err() != nil {
t.Fatal(ctx.Err())
}
}

Expand Down

0 comments on commit 10b64ec

Please sign in to comment.