From c55ededba893101a36b07dd5fb5228205e64712f Mon Sep 17 00:00:00 2001 From: Vernon Miller Date: Tue, 14 May 2024 08:43:24 -0600 Subject: [PATCH 1/5] make WriteTo non-blocking --- kv/memberlist/tcp_transport.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index 751ad1163..abd7d1c90 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -52,7 +52,10 @@ type TCPTransportConfig struct { // Timeout for writing packet data. Zero = no timeout. PacketWriteTimeout time.Duration `yaml:"packet_write_timeout" category:"advanced"` - // Transport logs lot of messages at debug level, so it deserves an extra flag for turning it on + // Maximum number of concurrent writes to other nodes. + MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"` + + // Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on TransportDebug bool `yaml:"-" category:"advanced"` // Where to put custom metrics. nil = don't register. @@ -73,6 +76,7 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.") f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") + f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 1, "Maximum number of concurrent writes to other nodes.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") @@ -88,6 +92,7 @@ type TCPTransport struct { packetCh chan *memberlist.Packet connCh chan net.Conn wg sync.WaitGroup + writeCh chan struct{} tcpListeners []net.Listener tlsConfig *tls.Config @@ -124,6 +129,7 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr logger: log.With(logger, "component", "memberlist TCPTransport"), packetCh: make(chan *memberlist.Packet), connCh: make(chan net.Conn), + writeCh: make(chan struct{}, config.MaxConcurrentWrites), } var err error @@ -426,7 +432,15 @@ func (t *TCPTransport) getAdvertisedAddr() string { func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { t.sentPackets.Inc() t.sentPacketsBytes.Add(float64(len(b))) + t.writeCh <- struct{}{} + go func() { + defer func() { <-t.writeCh }() + t.writeToAsync(b, addr) + }() + return time.Now(), nil +} +func (t *TCPTransport) writeToAsync(b []byte, addr string) { err := t.writeTo(b, addr) if err != nil { t.sentPacketsErrors.Inc() @@ -441,10 +455,7 @@ func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. - return time.Now(), nil } - - return time.Now(), nil } func (t *TCPTransport) writeTo(b []byte, addr string) error { From 516a2ce358dfa8f02505d8e391a3617b35f1b07f Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Wed, 2 Oct 2024 18:58:03 -0400 Subject: [PATCH 2/5] Try to make this PR ready to go - Create goroutines and keep them while the TCPTransport is alive. End them on the `Shutdown` function - Add `TestTCPTransportWriteToUnreachableAddr` test to check that writing is not blocking anymore (without this PR, it takes `writeCt * timeout` to run and it fails) --- kv/memberlist/tcp_transport.go | 76 ++++++++++++++++++++--------- kv/memberlist/tcp_transport_test.go | 58 ++++++++++++++++++++++ 2 files changed, 110 insertions(+), 24 deletions(-) diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index abd7d1c90..c74e50d17 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -76,13 +76,18 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.IntVar(&cfg.BindPort, prefix+"memberlist.bind-port", 7946, "Port to listen on for gossip messages.") f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") - f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 1, "Maximum number of concurrent writes to other nodes.") + f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") cfg.TLS.RegisterFlagsWithPrefix(prefix+"memberlist", f) } +type writeRequest struct { + b []byte + addr string +} + // TCPTransport is a memberlist.Transport implementation that uses TCP for both packet and stream // operations ("packet" and "stream" are terms used by memberlist). // It uses a new TCP connections for each operation. There is no connection reuse. @@ -92,10 +97,13 @@ type TCPTransport struct { packetCh chan *memberlist.Packet connCh chan net.Conn wg sync.WaitGroup - writeCh chan struct{} tcpListeners []net.Listener tlsConfig *tls.Config + writeMu sync.RWMutex + writeCh chan writeRequest + writeWG sync.WaitGroup + shutdown atomic.Int32 advertiseMu sync.RWMutex @@ -124,12 +132,20 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr // Build out the new transport. var ok bool + concurrentWrites := config.MaxConcurrentWrites + if concurrentWrites <= 0 { + concurrentWrites = 1 + } t := TCPTransport{ cfg: config, logger: log.With(logger, "component", "memberlist TCPTransport"), packetCh: make(chan *memberlist.Packet), connCh: make(chan net.Conn), - writeCh: make(chan struct{}, config.MaxConcurrentWrites), + writeCh: make(chan writeRequest), + } + + for i := 0; i < concurrentWrites; i++ { + go t.writeWorker() } var err error @@ -430,31 +446,34 @@ func (t *TCPTransport) getAdvertisedAddr() string { // WriteTo is a packet-oriented interface that fires off the given // payload to the given address. func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { - t.sentPackets.Inc() - t.sentPacketsBytes.Add(float64(len(b))) - t.writeCh <- struct{}{} - go func() { - defer func() { <-t.writeCh }() - t.writeToAsync(b, addr) - }() + if t.shutdown.Load() == 1 { + return time.Time{}, errors.New("transport is shutting down") + } + t.writeMu.RLock() + defer t.writeMu.RUnlock() + t.writeWG.Add(1) + t.writeCh <- writeRequest{b: b, addr: addr} return time.Now(), nil } -func (t *TCPTransport) writeToAsync(b []byte, addr string) { - err := t.writeTo(b, addr) - if err != nil { - t.sentPacketsErrors.Inc() +func (t *TCPTransport) writeWorker() { + for req := range t.writeCh { + b, addr := req.b, req.addr + t.sentPackets.Inc() + t.sentPacketsBytes.Add(float64(len(b))) + err := t.writeTo(b, addr) + if err != nil { + t.sentPacketsErrors.Inc() - logLevel := level.Warn(t.logger) - if strings.Contains(err.Error(), "connection refused") { - // The connection refused is a common error that could happen during normal operations when a node - // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. - logLevel = t.debugLog() + logLevel := level.Warn(t.logger) + if strings.Contains(err.Error(), "connection refused") { + // The connection refused is a common error that could happen during normal operations when a node + // shutdown (or crash). It shouldn't be considered a warning condition on the sender side. + logLevel = t.debugLog() + } + logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) } - logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) - - // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, - // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. + t.writeWG.Done() } } @@ -570,9 +589,12 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn { // Shutdown is called when memberlist is shutting down; this gives the // transport a chance to clean up any listeners. +// This will avoid log spam about errors when we shut down. func (t *TCPTransport) Shutdown() error { // This will avoid log spam about errors when we shut down. - t.shutdown.Store(1) + if old := t.shutdown.Swap(1); old == 1 { + return nil // already shut down + } // Rip through all the connections and shut them down. for _, conn := range t.tcpListeners { @@ -581,6 +603,12 @@ func (t *TCPTransport) Shutdown() error { // Block until all the listener threads have died. t.wg.Wait() + + // Wait until the write channel is empty and close it (to end the writeWorker goroutines). + t.writeMu.Lock() + defer t.writeMu.Unlock() + t.writeWG.Wait() + close(t.writeCh) return nil } diff --git a/kv/memberlist/tcp_transport_test.go b/kv/memberlist/tcp_transport_test.go index 310e11ecb..282bbc693 100644 --- a/kv/memberlist/tcp_transport_test.go +++ b/kv/memberlist/tcp_transport_test.go @@ -1,7 +1,10 @@ package memberlist import ( + "net" + "strings" "testing" + "time" "github.com/go-kit/log" "github.com/prometheus/client_golang/prometheus" @@ -9,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/grafana/dskit/concurrency" + "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" ) @@ -51,6 +55,8 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T _, err = transport.WriteTo([]byte("test"), testData.remoteAddr) require.NoError(t, err) + require.NoError(t, transport.Shutdown()) + if testData.expectedLogs != "" { assert.Contains(t, logs.String(), testData.expectedLogs) } @@ -61,6 +67,58 @@ func TestTCPTransport_WriteTo_ShouldNotLogAsWarningExpectedFailures(t *testing.T } } +type timeoutReader struct{} + +func (f *timeoutReader) ReadSecret(_ string) ([]byte, error) { + time.Sleep(1 * time.Second) + return nil, nil +} + +func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { + writeCt := 50 + + // Listen for TCP connections on a random port + freePorts, err := getFreePorts(1) + require.NoError(t, err) + addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]} + listener, err := net.ListenTCP("tcp", addr) + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.MaxConcurrentWrites = writeCt + cfg.PacketDialTimeout = 500 * time.Millisecond + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + // Configure TLS only for writes. The dialing should timeout (because of the timeoutReader) + transport.cfg.TLSEnabled = true + transport.cfg.TLS = tls.ClientConfig{ + Reader: &timeoutReader{}, + CertPath: "fake", + KeyPath: "fake", + CAPath: "fake", + } + + timeStart := time.Now() + + for i := 0; i < writeCt; i++ { + _, err = transport.WriteTo([]byte("test"), addr.String()) + require.NoError(t, err) + } + + require.NoError(t, transport.Shutdown()) + + gotErrorCt := strings.Count(logs.String(), "context deadline exceeded") + assert.Equal(t, writeCt, gotErrorCt, "expected %d errors, got %d", writeCt, gotErrorCt) + assert.GreaterOrEqual(t, time.Since(timeStart), 500*time.Millisecond, "expected to take at least 500ms (timeout duration)") + assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block") +} + func TestFinalAdvertiseAddr(t *testing.T) { tests := map[string]struct { advertiseAddr string From d88e163b9c4e630829f4ce63121b338e9fceeeaa Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Thu, 3 Oct 2024 10:12:14 -0400 Subject: [PATCH 3/5] Add CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3288acdeb..bb5cb0fbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -233,6 +233,7 @@ * [ENHANCEMENT] Adapt `metrics.SendSumOfGaugesPerTenant` to use `metrics.MetricOption`. #584 * [ENHANCEMENT] Cache: Add `.Add()` and `.Set()` methods to cache clients. #591 * [ENHANCEMENT] Cache: Add `.Advance()` methods to mock cache clients for easier testing of TTLs. #601 +* [ENHANCEMENT] Memberlist: Make `WriteTo` non-blocking. #525 * [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 From 95d45b353c9b38ec00f5ef1cb162169b577ed8ce Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Fri, 4 Oct 2024 12:44:46 -0400 Subject: [PATCH 4/5] Address PR comments - Rename CHANGELOG - Mutex lock on shutdown rather than write - Wait when workers are ended rather than for each write --- CHANGELOG.md | 2 +- kv/memberlist/tcp_transport.go | 37 ++++++++++++++++++++-------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bb5cb0fbd..7857018e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -233,7 +233,7 @@ * [ENHANCEMENT] Adapt `metrics.SendSumOfGaugesPerTenant` to use `metrics.MetricOption`. #584 * [ENHANCEMENT] Cache: Add `.Add()` and `.Set()` methods to cache clients. #591 * [ENHANCEMENT] Cache: Add `.Advance()` methods to mock cache clients for easier testing of TTLs. #601 -* [ENHANCEMENT] Memberlist: Make `WriteTo` non-blocking. #525 +* [ENHANCEMENT] Memberlist: Add concurrency to the transport's WriteTo method. #525 * [CHANGE] Backoff: added `Backoff.ErrCause()` which is like `Backoff.Err()` but returns the context cause if backoff is terminated because the context has been canceled. #538 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index c74e50d17..bd13d9393 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -19,7 +19,6 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "go.uber.org/atomic" dstls "github.com/grafana/dskit/crypto/tls" "github.com/grafana/dskit/flagext" @@ -100,11 +99,11 @@ type TCPTransport struct { tcpListeners []net.Listener tlsConfig *tls.Config - writeMu sync.RWMutex writeCh chan writeRequest writeWG sync.WaitGroup - shutdown atomic.Int32 + shutdown bool + shutdownMu sync.RWMutex advertiseMu sync.RWMutex advertiseAddr string @@ -145,6 +144,7 @@ func NewTCPTransport(config TCPTransportConfig, logger log.Logger, registerer pr } for i := 0; i < concurrentWrites; i++ { + t.writeWG.Add(1) go t.writeWorker() } @@ -227,7 +227,10 @@ func (t *TCPTransport) tcpListen(tcpLn net.Listener) { for { conn, err := tcpLn.Accept() if err != nil { - if s := t.shutdown.Load(); s == 1 { + t.shutdownMu.RLock() + isShuttingDown := t.shutdown + t.shutdownMu.RUnlock() + if isShuttingDown { break } @@ -446,17 +449,17 @@ func (t *TCPTransport) getAdvertisedAddr() string { // WriteTo is a packet-oriented interface that fires off the given // payload to the given address. func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { - if t.shutdown.Load() == 1 { + t.shutdownMu.RLock() + defer t.shutdownMu.RUnlock() // Unlock at the end to protect the chan + if t.shutdown { return time.Time{}, errors.New("transport is shutting down") } - t.writeMu.RLock() - defer t.writeMu.RUnlock() - t.writeWG.Add(1) t.writeCh <- writeRequest{b: b, addr: addr} return time.Now(), nil } func (t *TCPTransport) writeWorker() { + defer t.writeWG.Done() for req := range t.writeCh { b, addr := req.b, req.addr t.sentPackets.Inc() @@ -473,7 +476,6 @@ func (t *TCPTransport) writeWorker() { } logLevel.Log("msg", "WriteTo failed", "addr", addr, "err", err) } - t.writeWG.Done() } } @@ -591,24 +593,29 @@ func (t *TCPTransport) StreamCh() <-chan net.Conn { // transport a chance to clean up any listeners. // This will avoid log spam about errors when we shut down. func (t *TCPTransport) Shutdown() error { + t.shutdownMu.Lock() // This will avoid log spam about errors when we shut down. - if old := t.shutdown.Swap(1); old == 1 { + if t.shutdown { + t.shutdownMu.Unlock() return nil // already shut down } + // Set the shutdown flag and close the write channel. + t.shutdown = true + close(t.writeCh) + t.shutdownMu.Unlock() + // Rip through all the connections and shut them down. for _, conn := range t.tcpListeners { _ = conn.Close() } + // Wait until all write workers have finished. + t.writeWG.Wait() + // Block until all the listener threads have died. t.wg.Wait() - // Wait until the write channel is empty and close it (to end the writeWorker goroutines). - t.writeMu.Lock() - defer t.writeMu.Unlock() - t.writeWG.Wait() - close(t.writeCh) return nil } From ecd48530a5ce15c8a04d0269312058520f80f4e5 Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Tue, 8 Oct 2024 18:13:38 -0400 Subject: [PATCH 5/5] Address PR comments - Move variables around - Add timeout before dropping requests. This prevents blocking on the `WriteTo` function --- kv/memberlist/tcp_transport.go | 27 +++++++++++++++---- kv/memberlist/tcp_transport_test.go | 41 +++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/kv/memberlist/tcp_transport.go b/kv/memberlist/tcp_transport.go index bd13d9393..2010d3919 100644 --- a/kv/memberlist/tcp_transport.go +++ b/kv/memberlist/tcp_transport.go @@ -54,6 +54,9 @@ type TCPTransportConfig struct { // Maximum number of concurrent writes to other nodes. MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"` + // Timeout for acquiring one of the concurrent write slots. + AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"` + // Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on TransportDebug bool `yaml:"-" category:"advanced"` @@ -76,6 +79,7 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.") f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.") f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.") + f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.") f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.") f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.") @@ -99,11 +103,11 @@ type TCPTransport struct { tcpListeners []net.Listener tlsConfig *tls.Config - writeCh chan writeRequest - writeWG sync.WaitGroup - - shutdown bool shutdownMu sync.RWMutex + shutdown bool + writeCh chan writeRequest // this channel is protected by shutdownMu + + writeWG sync.WaitGroup advertiseMu sync.RWMutex advertiseAddr string @@ -454,7 +458,20 @@ func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) { if t.shutdown { return time.Time{}, errors.New("transport is shutting down") } - t.writeCh <- writeRequest{b: b, addr: addr} + + // Send the packet to the write workers + // If this blocks for too long (as configured), abort and log an error. + select { + case <-time.After(t.cfg.AcquireWriterTimeout): + level.Warn(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr) + t.sentPacketsErrors.Inc() + // WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors, + // but memberlist library doesn't seem to cope with that very well. That is why we return nil instead. + return time.Now(), nil + case t.writeCh <- writeRequest{b: b, addr: addr}: + // OK + } + return time.Now(), nil } diff --git a/kv/memberlist/tcp_transport_test.go b/kv/memberlist/tcp_transport_test.go index 282bbc693..1803c8280 100644 --- a/kv/memberlist/tcp_transport_test.go +++ b/kv/memberlist/tcp_transport_test.go @@ -3,6 +3,7 @@ package memberlist import ( "net" "strings" + "sync" "testing" "time" @@ -78,10 +79,7 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { writeCt := 50 // Listen for TCP connections on a random port - freePorts, err := getFreePorts(1) - require.NoError(t, err) - addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]} - listener, err := net.ListenTCP("tcp", addr) + listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) defer listener.Close() @@ -107,7 +105,7 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { timeStart := time.Now() for i := 0; i < writeCt; i++ { - _, err = transport.WriteTo([]byte("test"), addr.String()) + _, err = transport.WriteTo([]byte("test"), listener.Addr().String()) require.NoError(t, err) } @@ -119,6 +117,39 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) { assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block") } +func TestTCPTransportWriterAcquireTimeout(t *testing.T) { + // Listen for TCP connections on a random port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + logs := &concurrency.SyncBuffer{} + logger := log.NewLogfmtLogger(logs) + + cfg := TCPTransportConfig{} + flagext.DefaultValues(&cfg) + cfg.MaxConcurrentWrites = 1 + cfg.AcquireWriterTimeout = 1 * time.Millisecond // very short timeout + transport, err := NewTCPTransport(cfg, logger, nil) + require.NoError(t, err) + + writeCt := 100 + var reqWg sync.WaitGroup + for i := 0; i < writeCt; i++ { + reqWg.Add(1) + go func() { + defer reqWg.Done() + transport.WriteTo([]byte("test"), listener.Addr().String()) // nolint:errcheck + }() + } + reqWg.Wait() + + require.NoError(t, transport.Shutdown()) + gotErrorCt := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message") + assert.Less(t, gotErrorCt, writeCt, "expected to have less errors (%d) than total writes (%d). Some writes should pass.", gotErrorCt, writeCt) + assert.NotZero(t, gotErrorCt, "expected errors, got none") +} + func TestFinalAdvertiseAddr(t *testing.T) { tests := map[string]struct { advertiseAddr string