diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 93531ce225..92e5673b8d 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -60,6 +60,32 @@ update those files. There are `make` targets to help with generation: Don't forget to account for changes to generated files in tests. +You can build mocks/protos etc. for a single component with: + +``` +make mock-gen- + +e.g. +make mock-gen-aggregator +``` + +### Adding new mocks + +`mockgen` statements are centralized in a single generate.go file per component. The convention is: + +`src//generated/mocks/generate.go` + +e.g. for the aggregator +`src/aggregator/generated/mocks/generate.go` + +### Adding new proto definitions + +Proto definitions should be placed in: + +``` +src//generated/proto` +``` + ## Scoping Pull Requests Inspired by Phabricator's article about diff --git a/src/aggregator/client/conn.go b/src/aggregator/client/conn.go index d5fc78dbb4..705ff1d48f 100644 --- a/src/aggregator/client/conn.go +++ b/src/aggregator/client/conn.go @@ -21,6 +21,7 @@ package client import ( + "context" "errors" "math/rand" "net" @@ -29,6 +30,7 @@ import ( "github.com/m3db/m3/src/x/clock" xio "github.com/m3db/m3/src/x/io" + xnet "github.com/m3db/m3/src/x/net" "github.com/m3db/m3/src/x/retry" "github.com/uber-go/tally" @@ -59,7 +61,7 @@ type connection struct { connectWithLockFn connectWithLockFn sleepFn sleepFn nowFn clock.NowFn - conn *net.TCPConn + conn net.Conn rngFn retry.RngFn writeWithLockFn writeWithLockFn addr string @@ -74,6 +76,7 @@ type connection struct { numFailures int mtx sync.Mutex keepAlive bool + dialer xnet.ContextDialerFn } // newConnection creates a new connection. @@ -88,6 +91,7 @@ func newConnection(addr string, opts ConnectionOptions) *connection { maxThreshold: opts.MaxReconnectThreshold(), maxDuration: opts.MaxReconnectDuration(), writeRetryOpts: opts.WriteRetryOptions(), + dialer: opts.ContextDialer(), rngFn: rand.New(rand.NewSource(time.Now().UnixNano())).Int63n, nowFn: opts.ClockOptions().NowFn(), sleepFn: time.Sleep, @@ -166,27 +170,52 @@ func (c *connection) writeAttemptWithLock(data []byte) error { } func (c *connection) connectWithLock() error { + // TODO: propagate this all the way up the callstack. + ctx := context.TODO() + c.lastConnectAttemptNanos = c.nowFn().UnixNano() - conn, err := net.DialTimeout(tcpProtocol, c.addr, c.connTimeout) + + ctx, cancel := context.WithTimeout(ctx, c.connTimeout) + defer cancel() + + conn, err := c.dialContext(ctx, c.addr) if err != nil { c.metrics.connectError.Inc(1) return err } - tcpConn := conn.(*net.TCPConn) - if err := tcpConn.SetKeepAlive(c.keepAlive); err != nil { - c.metrics.setKeepAliveError.Inc(1) + // N.B.: If using a custom dialer which doesn't return *net.TCPConn, users are responsible for TCP keep alive options + // themselves. + if tcpConn, ok := conn.(keepAlivable); ok { + if err := tcpConn.SetKeepAlive(c.keepAlive); err != nil { + c.metrics.setKeepAliveError.Inc(1) + } } if c.conn != nil { c.conn.Close() // nolint: errcheck } - c.conn = tcpConn - c.writer.Reset(tcpConn) + c.conn = conn + c.writer.Reset(conn) return nil } +// Make sure net.TCPConn implements this; otherwise bad things will happen. +var _ keepAlivable = (*net.TCPConn)(nil) + +type keepAlivable interface { + SetKeepAlive(shouldKeepAlive bool) error +} + +func (c *connection) dialContext(ctx context.Context, addr string) (net.Conn, error) { + if dialer := c.dialer; dialer != nil { + return dialer(ctx, tcpProtocol, addr) + } + var dialer net.Dialer + return dialer.DialContext(ctx, tcpProtocol, addr) +} + func (c *connection) checkReconnectWithLock() error { // If we haven't accumulated enough failures to warrant another reconnect // and we haven't past the maximum duration since the last time we attempted diff --git a/src/aggregator/client/conn_mock_test.go b/src/aggregator/client/conn_mock_test.go new file mode 100644 index 0000000000..0a05b15507 --- /dev/null +++ b/src/aggregator/client/conn_mock_test.go @@ -0,0 +1,170 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: Conn) + +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Package client is a generated GoMock package. +package client + +import ( + "net" + "reflect" + "time" + + "github.com/golang/mock/gomock" +) + +// MockConn is a mock of Conn interface. +type MockConn struct { + ctrl *gomock.Controller + recorder *MockConnMockRecorder +} + +// MockConnMockRecorder is the mock recorder for MockConn. +type MockConnMockRecorder struct { + mock *MockConn +} + +// NewMockConn creates a new mock instance. +func NewMockConn(ctrl *gomock.Controller) *MockConn { + mock := &MockConn{ctrl: ctrl} + mock.recorder = &MockConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConn) EXPECT() *MockConnMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) +} + +// LocalAddr mocks base method. +func (m *MockConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr. +func (mr *MockConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockConn)(nil).LocalAddr)) +} + +// Read mocks base method. +func (m *MockConn) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockConnMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read), arg0) +} + +// RemoteAddr mocks base method. +func (m *MockConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr. +func (mr *MockConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockConn)(nil).RemoteAddr)) +} + +// SetDeadline mocks base method. +func (m *MockConn) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline. +func (mr *MockConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method. +func (m *MockConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockConn)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method. +func (m *MockConn) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline. +func (mr *MockConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockConn)(nil).SetWriteDeadline), arg0) +} + +// Write mocks base method. +func (m *MockConn) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockConnMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), arg0) +} diff --git a/src/aggregator/client/conn_options.go b/src/aggregator/client/conn_options.go index a249295340..98e7073f00 100644 --- a/src/aggregator/client/conn_options.go +++ b/src/aggregator/client/conn_options.go @@ -26,6 +26,7 @@ import ( "github.com/m3db/m3/src/x/clock" "github.com/m3db/m3/src/x/instrument" xio "github.com/m3db/m3/src/x/io" + xnet "github.com/m3db/m3/src/x/net" "github.com/m3db/m3/src/x/retry" ) @@ -111,6 +112,17 @@ type ConnectionOptions interface { // RWOptions returns the RW options. RWOptions() xio.Options + + // ContextDialer allows customizing the way an aggregator client the aggregator, at the TCP layer. + // By default, this is: + // (&net.ContextDialer{}).DialContext. This can be used to do a variety of things, such as forwarding a connection + // over a proxy. + // NOTE: if your xnet.ContextDialerFn returns anything other a *net.TCPConn, TCP options such as KeepAlivePeriod + // will *not* be applied automatically. It is your responsibility to make sure these get applied as needed in + // your custom xnet.ContextDialerFn. + ContextDialer() xnet.ContextDialerFn + // SetContextDialer sets ContextDialer() -- see that method. + SetContextDialer(dialer xnet.ContextDialerFn) ConnectionOptions } type connectionOptions struct { @@ -125,6 +137,7 @@ type connectionOptions struct { maxThreshold int multiplier int connKeepAlive bool + dialer xnet.ContextDialerFn } // NewConnectionOptions create a new set of connection options. @@ -147,6 +160,7 @@ func NewConnectionOptions() ConnectionOptions { maxDuration: defaultMaxReconnectDuration, writeRetryOpts: defaultWriteRetryOpts, rwOpts: xio.NewOptions(), + dialer: nil, // Will default to net.Dialer{}.DialContext } } @@ -259,3 +273,14 @@ func (o *connectionOptions) SetRWOptions(value xio.Options) ConnectionOptions { func (o *connectionOptions) RWOptions() xio.Options { return o.rwOpts } + +func (o *connectionOptions) ContextDialer() xnet.ContextDialerFn { + return o.dialer +} + +// SetContextDialer see ContextDialer. +func (o *connectionOptions) SetContextDialer(dialer xnet.ContextDialerFn) ConnectionOptions { + opts := *o + opts.dialer = dialer + return &opts +} diff --git a/src/aggregator/client/conn_test.go b/src/aggregator/client/conn_test.go index e7bbed48b0..6c1e6b330d 100644 --- a/src/aggregator/client/conn_test.go +++ b/src/aggregator/client/conn_test.go @@ -21,6 +21,7 @@ package client import ( + "context" "errors" "fmt" "math" @@ -31,9 +32,11 @@ import ( "github.com/m3db/m3/src/x/clock" + "github.com/golang/mock/gomock" "github.com/leanovate/gopter" "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -283,6 +286,75 @@ func TestConnectionWriteFailsOnSecondAttempt(t *testing.T) { require.Equal(t, 2, conn.threshold) } +type keepAlivableConn struct { + net.Conn + keepAlivable +} + +func TestConnectWithCustomDialer(t *testing.T) { + testData := []byte("foobar") + testConnectionTimeout := 5 * time.Second + + testWithConn := func(t *testing.T, netConn net.Conn) { + type args struct { + Ctx context.Context + Network string + Address string + } + var capturedArgs args + dialer := func(ctx context.Context, network string, address string) (net.Conn, error) { + capturedArgs = args{ + Ctx: ctx, + Network: network, + Address: address, + } + return netConn, nil + } + opts := testConnectionOptions(). + SetContextDialer(dialer). + SetConnectionTimeout(testConnectionTimeout) + addr := "127.0.0.1:5555" + + conn := newConnection(addr, opts) + start := time.Now() + require.NoError(t, conn.Write(testData)) + + assert.Equal(t, addr, capturedArgs.Address) + assert.Equal(t, tcpProtocol, capturedArgs.Network) + + deadline, ok := capturedArgs.Ctx.Deadline() + require.True(t, ok) + // Start is taken *before* we try to connect, so the deadline must = start + + testDialTimeout. + // Therefore deadline - start >= testDialTimeout. + assert.True(t, deadline.Sub(start) >= testConnectionTimeout) + } + + t.Run("non keep alivable conn", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockConn := NewMockConn(ctrl) + + mockConn.EXPECT().Write(testData) + mockConn.EXPECT().SetWriteDeadline(gomock.Any()) + testWithConn(t, mockConn) + }) + + t.Run("keep alivable conn", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockConn := NewMockConn(ctrl) + + mockConn.EXPECT().Write(testData) + mockConn.EXPECT().SetWriteDeadline(gomock.Any()) + + mockKeepAlivable := NewMockkeepAlivable(ctrl) + mockKeepAlivable.EXPECT().SetKeepAlive(true) + + testWithConn(t, keepAlivableConn{ + Conn: mockConn, + keepAlivable: mockKeepAlivable, + }) + }) +} + func TestConnectWriteToServer(t *testing.T) { data := []byte("foobar") diff --git a/src/aggregator/client/keep_alivable_mock_test.go b/src/aggregator/client/keep_alivable_mock_test.go new file mode 100644 index 0000000000..2883fdb8cb --- /dev/null +++ b/src/aggregator/client/keep_alivable_mock_test.go @@ -0,0 +1,68 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/m3db/m3/src/aggregator/client/conn.go + +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Package client is a generated GoMock package. +package client + +import ( + "reflect" + + "github.com/golang/mock/gomock" +) + +// MockkeepAlivable is a mock of keepAlivable interface. +type MockkeepAlivable struct { + ctrl *gomock.Controller + recorder *MockkeepAlivableMockRecorder +} + +// MockkeepAlivableMockRecorder is the mock recorder for MockkeepAlivable. +type MockkeepAlivableMockRecorder struct { + mock *MockkeepAlivable +} + +// NewMockkeepAlivable creates a new mock instance. +func NewMockkeepAlivable(ctrl *gomock.Controller) *MockkeepAlivable { + mock := &MockkeepAlivable{ctrl: ctrl} + mock.recorder = &MockkeepAlivableMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockkeepAlivable) EXPECT() *MockkeepAlivableMockRecorder { + return m.recorder +} + +// SetKeepAlive mocks base method. +func (m *MockkeepAlivable) SetKeepAlive(shouldKeepAlive bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetKeepAlive", shouldKeepAlive) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetKeepAlive indicates an expected call of SetKeepAlive. +func (mr *MockkeepAlivableMockRecorder) SetKeepAlive(shouldKeepAlive interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetKeepAlive", reflect.TypeOf((*MockkeepAlivable)(nil).SetKeepAlive), shouldKeepAlive) +} diff --git a/src/aggregator/generated/mocks/generate.go b/src/aggregator/generated/mocks/generate.go index 1b16756579..5c29467186 100644 --- a/src/aggregator/generated/mocks/generate.go +++ b/src/aggregator/generated/mocks/generate.go @@ -23,6 +23,7 @@ //go:generate sh -c "mockgen -package=client github.com/m3db/m3/src/aggregator/client Client,AdminClient | genclean -pkg github.com/m3db/m3/src/aggregator/client -out $GOPATH/src/github.com/m3db/m3/src/aggregator/client/client_mock.go" //go:generate sh -c "mockgen -package=handler github.com/m3db/m3/src/aggregator/aggregator/handler Handler | genclean -pkg github.com/m3db/m3/src/aggregator/aggregator/handler -out $GOPATH/src/github.com/m3db/m3/src/aggregator/aggregator/handler/handler_mock.go" //go:generate sh -c "mockgen -package=runtime github.com/m3db/m3/src/aggregator/runtime OptionsWatcher | genclean -pkg github.com/m3db/m3/src/aggregator/runtime -out $GOPATH/src/github.com/m3db/m3/src/aggregator/runtime/runtime_mock.go" +//go:generate sh -c "mockgen -package=client net Conn | genclean -pkg github.com/m3db/m3/src/aggregator/client -out $GOPATH/src/github.com/m3db/m3/src/aggregator/client/conn_mock_test.go" // mockgen rules for generating mocks for unexported interfaces (file mode). //go:generate sh -c "mockgen -package=aggregator -destination=$GOPATH/src/github.com/m3db/m3/src/aggregator/aggregator/flush_mgr_mock.go -source=$GOPATH/src/github.com/m3db/m3/src/aggregator/aggregator/flush_mgr.go" @@ -34,5 +35,6 @@ //go:generate sh -c "mockgen -package=deploy -destination=$GOPATH/src/github.com/m3db/m3/src/aggregator/tools/deploy/manager_mock.go -source=$GOPATH/src/github.com/m3db/m3/src/aggregator/tools/deploy/manager.go" //go:generate sh -c "mockgen -package=deploy -destination=$GOPATH/src/github.com/m3db/m3/src/aggregator/tools/deploy/planner_mock.go -source=$GOPATH/src/github.com/m3db/m3/src/aggregator/tools/deploy/planner.go" //go:generate sh -c "mockgen -package=deploy -destination=$GOPATH/src/github.com/m3db/m3/src/aggregator/tools/deploy/validator_mock.go -source=$GOPATH/src/github.com/m3db/m3/src/aggregator/tools/deploy/validator.go" +//go:generate sh -c "mockgen -package=client -destination=$GOPATH/src/github.com/m3db/m3/src/aggregator/client/keep_alivable_mock_test.go -source=$GOPATH/src/github.com/m3db/m3/src/aggregator/client/conn.go client keepAlivableConn" package mocks diff --git a/src/msg/producer/config/writer.go b/src/msg/producer/config/writer.go index 782f1d45f1..820e555ba2 100644 --- a/src/msg/producer/config/writer.go +++ b/src/msg/producer/config/writer.go @@ -33,6 +33,7 @@ import ( "github.com/m3db/m3/src/msg/topic" "github.com/m3db/m3/src/x/instrument" xio "github.com/m3db/m3/src/x/io" + xnet "github.com/m3db/m3/src/x/net" "github.com/m3db/m3/src/x/pool" "github.com/m3db/m3/src/x/retry" @@ -52,7 +53,7 @@ type ConnectionConfiguration struct { ReadBufferSize *int `yaml:"readBufferSize"` // ContextDialer specifies a custom dialer to use when creating TCP connections to the consumer. // See writer.ConnectionOptions.ContextDialer for details. - ContextDialer writer.ContextDialerFn `yaml:"-"` // not serializable + ContextDialer xnet.ContextDialerFn `yaml:"-"` // not serializable } // NewOptions creates connection options. diff --git a/src/msg/producer/writer/options.go b/src/msg/producer/writer/options.go index 1a7e97eaa1..7b488adc79 100644 --- a/src/msg/producer/writer/options.go +++ b/src/msg/producer/writer/options.go @@ -21,8 +21,6 @@ package writer import ( - "context" - "net" "time" "github.com/m3db/m3/src/cluster/placement" @@ -30,6 +28,7 @@ import ( "github.com/m3db/m3/src/msg/protocol/proto" "github.com/m3db/m3/src/msg/topic" "github.com/m3db/m3/src/x/instrument" + xnet "github.com/m3db/m3/src/x/net" "github.com/m3db/m3/src/x/retry" ) @@ -55,10 +54,6 @@ const ( defaultWriterRetryInitialBackoff = time.Second * 5 ) -// ContextDialerFn allows customization of how a m3msg Writer connects to producer endpoints. -// See ConnectionOptions#ContextDialer -type ContextDialerFn func(ctx context.Context, network string, address string) (net.Conn, error) - // ConnectionOptions configs the connections. type ConnectionOptions interface { // NumConnections returns the number of connections. @@ -70,13 +65,13 @@ type ConnectionOptions interface { // ContextDialer allows customizing the way a m3msg Writer connects to producer endpoints. By default, this is: // (&net.ContextDialer{}).DialContext. This can be used to do a variety of things, such as forwarding a connection // over a proxy. - // NOTE: if your ContextDialerFn returns anything other a *net.TCPConn, TCP options such as KeepAlivePeriod + // NOTE: if your xnet.ContextDialerFn returns anything other a *net.TCPConn, TCP options such as KeepAlivePeriod // will *not* be applied automatically. It is your responsibility to make sure these get applied as needed in - // your custom ContextDialerFn. - ContextDialer() ContextDialerFn + // your custom xnet.ContextDialerFn. + ContextDialer() xnet.ContextDialerFn // SetContextDialer see ContextDialer. - SetContextDialer(fn ContextDialerFn) ConnectionOptions + SetContextDialer(fn xnet.ContextDialerFn) ConnectionOptions // DialTimeout returns the dial timeout. DialTimeout() time.Duration @@ -144,7 +139,7 @@ type connectionOptions struct { writeBufferSize int readBufferSize int iOpts instrument.Options - dialer ContextDialerFn + dialer xnet.ContextDialerFn } // NewConnectionOptions creates ConnectionOptions. @@ -184,11 +179,11 @@ func (opts *connectionOptions) SetDialTimeout(value time.Duration) ConnectionOpt return &o } -func (opts *connectionOptions) ContextDialer() ContextDialerFn { +func (opts *connectionOptions) ContextDialer() xnet.ContextDialerFn { return opts.dialer } -func (opts *connectionOptions) SetContextDialer(fn ContextDialerFn) ConnectionOptions { +func (opts *connectionOptions) SetContextDialer(fn xnet.ContextDialerFn) ConnectionOptions { o := *opts o.dialer = fn return &o diff --git a/src/x/net/context_dialerfn.go b/src/x/net/context_dialerfn.go new file mode 100644 index 0000000000..f77a8a6daa --- /dev/null +++ b/src/x/net/context_dialerfn.go @@ -0,0 +1,35 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package net + +import ( + "context" + "net" +) + +// ContextDialerFn allows customization of how a process makes its TCP connections. This is the same pattern/function +// signature used by grpc.WithContextDialer -- we just define it here for convenience (it's used in multiple places +// across the M3 codebase). +// It is implemented by at least net.Dialer +type ContextDialerFn func(ctx context.Context, network string, address string) (net.Conn, error) + +// Assert that net.Dialer.DialContext implements our interface. +var _ ContextDialerFn = (*net.Dialer)(nil).DialContext