diff --git a/client/clients/tso/client.go b/client/clients/tso/client.go index 095f0e8dcb3..c2664789822 100644 --- a/client/clients/tso/client.go +++ b/client/clients/tso/client.go @@ -485,7 +485,7 @@ func (c *Cli) tryConnectToTSOWithProxy(ctx context.Context) error { addrTrim := tlsutil.TrimHTTPPrefix(addr) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) } - c.conCtxMgr.StoreIfNotExist(ctx, addr, stream) + c.conCtxMgr.Store(ctx, addr, stream) continue } log.Error("[tso] create the tso stream failed", diff --git a/client/clients/tso/dispatcher_test.go b/client/clients/tso/dispatcher_test.go index 9280acfc257..7e5554c7c7b 100644 --- a/client/clients/tso/dispatcher_test.go +++ b/client/clients/tso/dispatcher_test.go @@ -70,7 +70,7 @@ func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context) bool } else { stream = m.createStream(ctx) } - m.conCtxMgr.StoreIfNotExist(ctx, mockStreamURL, stream) + m.conCtxMgr.Store(ctx, mockStreamURL, stream) return true } diff --git a/client/pkg/connectionctx/manager.go b/client/pkg/connectionctx/manager.go index 78afd85c7fc..bae95e963ab 100644 --- a/client/pkg/connectionctx/manager.go +++ b/client/pkg/connectionctx/manager.go @@ -26,7 +26,7 @@ type connectionCtx[T any] struct { Cancel context.CancelFunc // Current URL of the stream connection. StreamURL string - // Current stream to send gRPC requests. + // Current stream to send the gRPC requests. Stream T } @@ -51,44 +51,52 @@ func (c *Manager[T]) Exist(url string) bool { return ok } -// StoreIfNotExist is used to store the connection context if it does not exist before. -func (c *Manager[T]) StoreIfNotExist(ctx context.Context, url string, stream T) { - c.RWMutex.Lock() - defer c.RWMutex.Unlock() +// Store is used to store the connection context, `overwrite` is used to force the store operation +// no matter whether the connection context exists before, which is false by default. +func (c *Manager[T]) Store(ctx context.Context, url string, stream T, overwrite ...bool) { + c.Lock() + defer c.Unlock() + overwriteFlag := false + if len(overwrite) > 0 { + overwriteFlag = overwrite[0] + } _, ok := c.connectionCtxs[url] - if ok { + if !overwriteFlag && ok { return } + c.storeLocked(ctx, url, stream) +} + +func (c *Manager[T]) storeLocked(ctx context.Context, url string, stream T) { + c.releaseLocked(url) cctx, cancel := context.WithCancel(ctx) c.connectionCtxs[url] = &connectionCtx[T]{cctx, cancel, url, stream} } // ExclusivelyStore is used to store the connection context exclusively. It will release // all other connection contexts. `stream` is optional, if it is not provided, all -// connection contexts other than the given `url` will be cleared. +// connection contexts other than the given `url` will be released. func (c *Manager[T]) ExclusivelyStore(ctx context.Context, url string, stream ...T) { - c.RWMutex.Lock() - defer c.RWMutex.Unlock() + c.Lock() + defer c.Unlock() // Remove all other `connectionCtx`s. - for curURL := range c.connectionCtxs { - if curURL == url { - continue - } - c.releaseLocked(curURL) - } + c.gcLocked(func(curURL string) bool { + return curURL != url + }) if len(stream) == 0 { return } - // Release the old connection context if it exists. - c.releaseLocked(url) - cctx, cancel := context.WithCancel(ctx) - c.connectionCtxs[url] = &connectionCtx[T]{cctx, cancel, url, stream[0]} + c.storeLocked(ctx, url, stream[0]) } // GC is used to release all connection contexts that match the given condition. func (c *Manager[T]) GC(condition func(url string) bool) { - c.RWMutex.Lock() - defer c.RWMutex.Unlock() + c.Lock() + defer c.Unlock() + c.gcLocked(condition) +} + +func (c *Manager[T]) gcLocked(condition func(url string) bool) { for url := range c.connectionCtxs { if condition(url) { c.releaseLocked(url)