Skip to content

Expose NetRPCConfig #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/yamux"
)

// If this is 1, then we've called CleanupClients. This can be used
Expand Down Expand Up @@ -202,6 +203,27 @@ type ClientConfig struct {
//
// You cannot Reattach to a server with this option enabled.
AutoMTLS bool

// NetRpcConfig allows configuring some properties of the connection
// if the protocol is ProtocolNetRPC.
NetRPCConfig *NetRPCConfig
}

// NetRPCConfig allows providing properties of an underlying connection
// via the ClientConfig.NetRPCConfig field.
type NetRPCConfig struct {
// EnableKeepalive is used to do a period keep alive
// messages using a ping.
EnableKeepAlive bool

// KeepAliveInterval is how often to perform the keep alive
KeepAliveInterval time.Duration

// ConnectionWriteTimeout is meant to be a "safety valve" timeout after
// we which will suspect a problem with the underlying connection and
// close it. This is only applied to writes, where's there's generally
// an expectation that things will move along quickly.
ConnectionWriteTimeout time.Duration
}

// ReattachConfig is used to configure a client to reattach to an
Expand Down Expand Up @@ -281,6 +303,15 @@ func CleanupClients() {
wg.Wait()
}

func DefaultNetRPCConfig() (netRPCConfig *NetRPCConfig) {
defaultYamuxConfig := yamux.DefaultConfig()
return &NetRPCConfig{
ConnectionWriteTimeout: defaultYamuxConfig.ConnectionWriteTimeout,
EnableKeepAlive: defaultYamuxConfig.EnableKeepAlive,
KeepAliveInterval: defaultYamuxConfig.KeepAliveInterval,
}
}

// Creates a new plugin client which manages the lifecycle of an external
// plugin and gets the address for the RPC connection.
//
Expand Down
91 changes: 91 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"path/filepath"
"strings"
"sync"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -184,6 +185,96 @@ func TestClient_testInterface(t *testing.T) {
}
}

func TestClient_keepAliveEnabled(t *testing.T) {
process := helperProcess("test-interface")
c := NewClient(&ClientConfig{
Cmd: process,
HandshakeConfig: testHandshake,
Plugins: testPluginMap,
AllowedProtocols: []Protocol{ProtocolNetRPC},
NetRPCConfig: &NetRPCConfig{
EnableKeepAlive: true,
KeepAliveInterval: 1 * time.Millisecond,
ConnectionWriteTimeout: 100 * time.Millisecond,
},
})
defer c.Kill()

// Grab the RPC client
client, err := c.Client()
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

// Grab the impl
raw, err := client.Dispense("test")
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

_, ok := raw.(testInterface)
if !ok {
t.Fatalf("bad: %#v", raw)
}

defer c.process.Signal(syscall.SIGCONT)
c.process.Signal(syscall.SIGSTOP)

select {
case <-c.doneCtx.Done():
case <-time.After(time.Second * 2):
t.Fatal("Context was not closed")
}

c.process.Signal(syscall.SIGCONT)
c.Kill()
}

func TestClient_keepAliveDisabled(t *testing.T) {
process := helperProcess("test-interface")
c := NewClient(&ClientConfig{
Cmd: process,
HandshakeConfig: testHandshake,
Plugins: testPluginMap,
AllowedProtocols: []Protocol{ProtocolNetRPC},
NetRPCConfig: &NetRPCConfig{
EnableKeepAlive: false,
KeepAliveInterval: 1 * time.Millisecond,
ConnectionWriteTimeout: 100 * time.Millisecond,
},
})
defer c.Kill()

// Grab the RPC client
client, err := c.Client()
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

// Grab the impl
raw, err := client.Dispense("test")
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

_, ok := raw.(testInterface)
if !ok {
t.Fatalf("bad: %#v", raw)
}

defer c.process.Signal(syscall.SIGCONT)
c.process.Signal(syscall.SIGSTOP)

select {
case <-c.doneCtx.Done():
t.Fatal("Context was closed")
case <-time.After(time.Second * 2):
}

c.process.Signal(syscall.SIGCONT)
c.Kill()
}

func TestClient_grpc_servercrash(t *testing.T) {
process := helperProcess("test-grpc")
c := NewClient(&ClientConfig{
Expand Down
60 changes: 60 additions & 0 deletions net.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package plugin

import (
"context"
"net"
"time"
)

type connWithCancel struct {
conn net.Conn
context.CancelFunc
}

func (wrapper *connWithCancel) Read(p []byte) (int, error) {
return wrapper.conn.Read(p)
}

func (wrapper *connWithCancel) Write(p []byte) (int, error) {
return wrapper.conn.Write(p)
}

func (wrapper *connWithCancel) Close() error {
err := wrapper.conn.Close()
wrapper.CancelFunc()
return err
}

func (wrapper *connWithCancel) LocalAddr() net.Addr {
return wrapper.conn.LocalAddr()
}

func (wrapper *connWithCancel) RemoteAddr() net.Addr {
return wrapper.conn.RemoteAddr()
}
func (wrapper *connWithCancel) SetDeadline(t time.Time) error {
return wrapper.conn.SetDeadline(t)
}

func (wrapper *connWithCancel) SetReadDeadline(t time.Time) error {
return wrapper.conn.SetReadDeadline(t)
}

func (wrapper *connWithCancel) SetWriteDeadline(t time.Time) error {
return wrapper.conn.SetWriteDeadline(t)
}

// NewConnWithCancel returns a net.Conn that signals the specified
// context.CancelFunc when the net.Conn object is closed.
func NewConnWithCancel(
conn net.Conn,
cancelFunc context.CancelFunc) (
wrapper net.Conn) {

wrapper = &connWithCancel{
conn: conn,
CancelFunc: cancelFunc,
}

return
}
24 changes: 19 additions & 5 deletions rpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ func newRPCClient(c *Client) (*RPCClient, error) {
if err != nil {
return nil, err
}
conn = NewConnWithCancel(conn, c.ctxCancel)

if tcpConn, ok := conn.(*net.TCPConn); ok {
// Make sure to set keep alive so that the connection doesn't die
tcpConn.SetKeepAlive(true)
Expand All @@ -37,8 +39,16 @@ func newRPCClient(c *Client) (*RPCClient, error) {
conn = tls.Client(conn, c.config.TLSConfig)
}

yamuxConfig := yamux.DefaultConfig()
netRPCConfig := c.config.NetRPCConfig
if netRPCConfig != nil {
yamuxConfig.EnableKeepAlive = netRPCConfig.EnableKeepAlive
yamuxConfig.KeepAliveInterval = netRPCConfig.KeepAliveInterval
yamuxConfig.ConnectionWriteTimeout = netRPCConfig.ConnectionWriteTimeout
}

// Create the actual RPC client
result, err := NewRPCClient(conn, c.config.Plugins)
result, err := newRpcClientWithConfig(yamuxConfig, conn, c.config.Plugins)
if err != nil {
conn.Close()
return nil, err
Expand All @@ -56,11 +66,9 @@ func newRPCClient(c *Client) (*RPCClient, error) {
return result, nil
}

// NewRPCClient creates a client from an already-open connection-like value.
// Dial is typically used instead.
func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) {
func newRpcClientWithConfig(yamuxConfig *yamux.Config, conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) {
// Create the yamux client so we can multiplex
mux, err := yamux.Client(conn, nil)
mux, err := yamux.Client(conn, yamuxConfig)
if err != nil {
conn.Close()
return nil, err
Expand Down Expand Up @@ -97,6 +105,12 @@ func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClien
}, nil
}

// NewRPCClient creates a client from an already-open connection-like value.
// Dial is typically used instead.
func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) {
return newRpcClientWithConfig(nil, conn, plugins)
}

// SyncStreams should be called to enable syncing of stdout,
// stderr with the plugin.
//
Expand Down