diff --git a/client/clients/tso/client.go b/client/clients/tso/client.go index c26dd25f2ad..d24dba52394 100644 --- a/client/clients/tso/client.go +++ b/client/clients/tso/client.go @@ -36,6 +36,7 @@ import ( "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" + cctx "github.com/tikv/pd/client/pkg/connectionctx" "github.com/tikv/pd/client/pkg/utils/grpcutil" "github.com/tikv/pd/client/pkg/utils/tlsutil" sd "github.com/tikv/pd/client/servicediscovery" @@ -80,7 +81,9 @@ type Cli struct { svcDiscovery sd.ServiceDiscovery tsoStreamBuilderFactory // leaderURL is the URL of the TSO leader. - leaderURL atomic.Value + leaderURL atomic.Value + conCtxMgr *cctx.Manager[*tsoStream] + updateConCtxsCh chan struct{} // tsoReqPool is the pool to recycle `*tsoRequest`. tsoReqPool *sync.Pool @@ -100,6 +103,8 @@ func NewClient( option: option, svcDiscovery: svcDiscovery, tsoStreamBuilderFactory: factory, + conCtxMgr: cctx.NewManager[*tsoStream](), + updateConCtxsCh: make(chan struct{}, 1), tsoReqPool: &sync.Pool{ New: func() any { return &Request{ @@ -122,6 +127,8 @@ func (c *Cli) getOption() *opt.Option { return c.option } func (c *Cli) getServiceDiscovery() sd.ServiceDiscovery { return c.svcDiscovery } +func (c *Cli) getConnectionCtxMgr() *cctx.Manager[*tsoStream] { return c.conCtxMgr } + func (c *Cli) getDispatcher() *tsoDispatcher { return c.dispatcher.Load() } @@ -133,6 +140,8 @@ func (c *Cli) GetRequestPool() *sync.Pool { // Setup initializes the TSO client. func (c *Cli) Setup() { + // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. + go c.connectionCtxsUpdater() if err := c.svcDiscovery.CheckMemberChanged(); err != nil { log.Warn("[tso] failed to check member changed", errs.ZapError(err)) } @@ -154,9 +163,12 @@ func (c *Cli) Close() { log.Info("[tso] tso client is closed") } -// scheduleUpdateTSOConnectionCtxs update the TSO connection contexts. +// scheduleUpdateTSOConnectionCtxs schedules the update of the TSO connection contexts. func (c *Cli) scheduleUpdateTSOConnectionCtxs() { - c.getDispatcher().scheduleUpdateConnectionCtxs() + select { + case c.updateConCtxsCh <- struct{}{}: + default: + } } // GetTSORequest gets a TSO request from the pool. @@ -231,25 +243,69 @@ func (c *Cli) backupClientConn() (*grpc.ClientConn, string) { return nil, "" } -// tsoConnectionContext is used to store the context of a TSO stream connection. -type tsoConnectionContext struct { - ctx context.Context - cancel context.CancelFunc - // Current URL of the stream connection. - streamURL string - // Current stream to send gRPC requests. - stream *tsoStream +// connectionCtxsUpdater updates the `connectionCtxs` regularly. +func (c *Cli) connectionCtxsUpdater() { + log.Info("[tso] start tso connection contexts updater") + + var updateTicker = &time.Ticker{} + setNewUpdateTicker := func(interval time.Duration) { + if updateTicker.C != nil { + updateTicker.Stop() + } + if interval == 0 { + updateTicker = &time.Ticker{} + } else { + updateTicker = time.NewTicker(interval) + } + } + // If the TSO Follower Proxy is enabled, set the update interval to the member update interval. + if c.option.GetEnableTSOFollowerProxy() { + setNewUpdateTicker(sd.MemberUpdateInterval) + } + // Set to nil before returning to ensure that the existing ticker can be GC. + defer setNewUpdateTicker(0) + + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + for { + c.updateConnectionCtxs(ctx) + select { + case <-ctx.Done(): + log.Info("[tso] exit tso connection contexts updater") + return + case <-c.option.EnableTSOFollowerProxyCh: + enableTSOFollowerProxy := c.option.GetEnableTSOFollowerProxy() + log.Info("[tso] tso follower proxy status changed", + zap.Bool("enable", enableTSOFollowerProxy)) + if enableTSOFollowerProxy && updateTicker.C == nil { + // Because the TSO Follower Proxy is enabled, + // the periodic check needs to be performed. + setNewUpdateTicker(sd.MemberUpdateInterval) + failpoint.Inject("speedUpTsoDispatcherUpdateInterval", func() { + setNewUpdateTicker(10 * time.Millisecond) + }) + } else if !enableTSOFollowerProxy && updateTicker.C != nil { + // Because the TSO Follower Proxy is disabled, + // the periodic check needs to be turned off. + setNewUpdateTicker(0) + } + case <-updateTicker.C: + // Triggered periodically when the TSO Follower Proxy is enabled. + case <-c.updateConCtxsCh: + // Triggered by the leader/follower change. + } + } } // updateConnectionCtxs will choose the proper way to update the connections. // It will return a bool to indicate whether the update is successful. -func (c *Cli) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { +func (c *Cli) updateConnectionCtxs(ctx context.Context) bool { // Normal connection creating, it will be affected by the `enableForwarding`. createTSOConnection := c.tryConnectToTSO if c.option.GetEnableTSOFollowerProxy() { createTSOConnection = c.tryConnectToTSOWithProxy } - if err := createTSOConnection(ctx, connectionCtxs); err != nil { + if err := createTSOConnection(ctx); err != nil { log.Error("[tso] update connection contexts failed", errs.ZapError(err)) return false } @@ -260,30 +316,13 @@ func (c *Cli) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map // and enableForwarding is true, it will create a new connection to a follower to do the forwarding, // while a new daemon will be created also to switch back to a normal leader connection ASAP the // connection comes back to normal. -func (c *Cli) tryConnectToTSO( - ctx context.Context, - connectionCtxs *sync.Map, -) error { +func (c *Cli) tryConnectToTSO(ctx context.Context) error { var ( - networkErrNum uint64 - err error - stream *tsoStream - url string - cc *grpc.ClientConn - updateAndClear = func(newURL string, connectionCtx *tsoConnectionContext) { - // Only store the `connectionCtx` if it does not exist before. - if connectionCtx != nil { - connectionCtxs.LoadOrStore(newURL, connectionCtx) - } - // Remove all other `connectionCtx`s. - connectionCtxs.Range(func(url, cc any) bool { - if url.(string) != newURL { - cc.(*tsoConnectionContext).cancel() - connectionCtxs.Delete(url) - } - return true - }) - } + networkErrNum uint64 + err error + stream *tsoStream + url string + cc *grpc.ClientConn ) ticker := time.NewTicker(constants.RetryInterval) @@ -292,9 +331,9 @@ func (c *Cli) tryConnectToTSO( for range constants.MaxRetryTimes { c.svcDiscovery.ScheduleCheckMemberChanged() cc, url = c.getTSOLeaderClientConn() - if _, ok := connectionCtxs.Load(url); ok { + if c.conCtxMgr.Exist(url) { // Just trigger the clean up of the stale connection contexts. - updateAndClear(url, nil) + c.conCtxMgr.CleanAllAndStore(ctx, url) return nil } if cc != nil { @@ -305,7 +344,7 @@ func (c *Cli) tryConnectToTSO( err = status.New(codes.Unavailable, "unavailable").Err() }) if stream != nil && err == nil { - updateAndClear(url, &tsoConnectionContext{cctx, cancel, url, stream}) + c.conCtxMgr.CleanAllAndStore(ctx, url, stream) return nil } @@ -348,9 +387,9 @@ func (c *Cli) tryConnectToTSO( forwardedHostTrim := tlsutil.TrimHTTPPrefix(forwardedHost) addr := tlsutil.TrimHTTPPrefix(backupURL) // the goroutine is used to check the network and change back to the original stream - go c.checkLeader(ctx, cancel, forwardedHostTrim, addr, url, updateAndClear) + go c.checkLeader(ctx, cancel, forwardedHostTrim, addr, url) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addr).Set(1) - updateAndClear(backupURL, &tsoConnectionContext{cctx, cancel, backupURL, stream}) + c.conCtxMgr.CleanAllAndStore(ctx, backupURL, stream) return nil } cancel() @@ -363,7 +402,6 @@ func (c *Cli) checkLeader( ctx context.Context, forwardCancel context.CancelFunc, forwardedHostTrim, addr, url string, - updateAndClear func(newAddr string, connectionCtx *tsoConnectionContext), ) { defer func() { // cancel the forward stream @@ -396,7 +434,7 @@ func (c *Cli) checkLeader( stream, err := c.tsoStreamBuilderFactory.makeBuilder(cc).build(cctx, cancel, c.option.Timeout) if err == nil && stream != nil { log.Info("[tso] recover the original tso stream since the network has become normal", zap.String("url", url)) - updateAndClear(url, &tsoConnectionContext{cctx, cancel, url, stream}) + c.conCtxMgr.CleanAllAndStore(ctx, url, stream) return } } @@ -413,10 +451,7 @@ func (c *Cli) checkLeader( // tryConnectToTSOWithProxy will create multiple streams to all the service endpoints to work as // a TSO proxy to reduce the pressure of the main serving service endpoint. -func (c *Cli) tryConnectToTSOWithProxy( - ctx context.Context, - connectionCtxs *sync.Map, -) error { +func (c *Cli) tryConnectToTSOWithProxy(ctx context.Context) error { tsoStreamBuilders := c.getAllTSOStreamBuilders() leaderAddr := c.svcDiscovery.GetServingURL() forwardedHost := c.getLeaderURL() @@ -424,20 +459,17 @@ func (c *Cli) tryConnectToTSOWithProxy( return errors.Errorf("cannot find the tso leader") } // GC the stale one. - connectionCtxs.Range(func(addr, cc any) bool { - addrStr := addr.(string) - if _, ok := tsoStreamBuilders[addrStr]; !ok { + c.conCtxMgr.GC(func(addr string) bool { + _, ok := tsoStreamBuilders[addr] + if !ok { log.Info("[tso] remove the stale tso stream", - zap.String("addr", addrStr)) - cc.(*tsoConnectionContext).cancel() - connectionCtxs.Delete(addr) + zap.String("addr", addr)) } - return true + return !ok }) // Update the missing one. for addr, tsoStreamBuilder := range tsoStreamBuilders { - _, ok := connectionCtxs.Load(addr) - if ok { + if c.conCtxMgr.Exist(addr) { continue } log.Info("[tso] try to create tso stream", zap.String("addr", addr)) @@ -456,7 +488,7 @@ func (c *Cli) tryConnectToTSOWithProxy( addrTrim := tlsutil.TrimHTTPPrefix(addr) metrics.RequestForwarded.WithLabelValues(forwardedHostTrim, addrTrim).Set(1) } - connectionCtxs.Store(addr, &tsoConnectionContext{cctx, cancel, addr, stream}) + c.conCtxMgr.Store(ctx, addr, stream) continue } log.Error("[tso] create the tso stream failed", diff --git a/client/clients/tso/dispatcher.go b/client/clients/tso/dispatcher.go index 9ee44db27e4..1cc2b2aa940 100644 --- a/client/clients/tso/dispatcher.go +++ b/client/clients/tso/dispatcher.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "math" - "math/rand" "runtime/trace" "sync" "sync/atomic" @@ -36,33 +35,13 @@ import ( "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/pkg/batch" + cctx "github.com/tikv/pd/client/pkg/connectionctx" + "github.com/tikv/pd/client/pkg/deadline" "github.com/tikv/pd/client/pkg/retry" - "github.com/tikv/pd/client/pkg/utils/timerutil" "github.com/tikv/pd/client/pkg/utils/tsoutil" sd "github.com/tikv/pd/client/servicediscovery" ) -// deadline is used to control the TS request timeout manually, -// it will be sent to the `tsDeadlineCh` to be handled by the `watchTSDeadline` goroutine. -type deadline struct { - timer *time.Timer - done chan struct{} - cancel context.CancelFunc -} - -func newTSDeadline( - timeout time.Duration, - done chan struct{}, - cancel context.CancelFunc, -) *deadline { - timer := timerutil.GlobalTimerPool.Get(timeout) - return &deadline{ - timer: timer, - done: done, - cancel: cancel, - } -} - type tsoInfo struct { tsoServer string reqKeyspaceGroupID uint32 @@ -76,7 +55,8 @@ type tsoInfo struct { type tsoServiceProvider interface { getOption() *opt.Option getServiceDiscovery() sd.ServiceDiscovery - updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool + getConnectionCtxMgr() *cctx.Manager[*tsoStream] + updateConnectionCtxs(ctx context.Context) bool } const dispatcherCheckRPCConcurrencyInterval = time.Second * 5 @@ -85,12 +65,10 @@ type tsoDispatcher struct { ctx context.Context cancel context.CancelFunc - provider tsoServiceProvider - // URL -> *connectionContext - connectionCtxs *sync.Map - tsoRequestCh chan *Request - tsDeadlineCh chan *deadline - latestTSOInfo atomic.Pointer[tsoInfo] + provider tsoServiceProvider + tsoRequestCh chan *Request + deadlineWatcher *deadline.Watcher + latestTSOInfo atomic.Pointer[tsoInfo] // For reusing `*batchController` objects batchBufferPool *sync.Pool @@ -102,8 +80,6 @@ type tsoDispatcher struct { lastCheckConcurrencyTime time.Time tokenCount int rpcConcurrency int - - updateConnectionCtxsCh chan struct{} } func newTSODispatcher( @@ -122,12 +98,11 @@ func newTSODispatcher( tokenCh := make(chan struct{}, tokenChCapacity) td := &tsoDispatcher{ - ctx: dispatcherCtx, - cancel: dispatcherCancel, - provider: provider, - connectionCtxs: &sync.Map{}, - tsoRequestCh: tsoRequestCh, - tsDeadlineCh: make(chan *deadline, tokenChCapacity), + ctx: dispatcherCtx, + cancel: dispatcherCancel, + provider: provider, + tsoRequestCh: tsoRequestCh, + deadlineWatcher: deadline.NewWatcher(dispatcherCtx, tokenChCapacity, "tso"), batchBufferPool: &sync.Pool{ New: func() any { return batch.NewController[*Request]( @@ -137,44 +112,11 @@ func newTSODispatcher( ) }, }, - tokenCh: tokenCh, - updateConnectionCtxsCh: make(chan struct{}, 1), + tokenCh: tokenCh, } - go td.watchTSDeadline() return td } -func (td *tsoDispatcher) watchTSDeadline() { - log.Info("[tso] start tso deadline watcher") - defer log.Info("[tso] exit tso deadline watcher") - for { - select { - case d := <-td.tsDeadlineCh: - select { - case <-d.timer.C: - log.Error("[tso] tso request is canceled due to timeout", - errs.ZapError(errs.ErrClientGetTSOTimeout)) - d.cancel() - timerutil.GlobalTimerPool.Put(d.timer) - case <-d.done: - timerutil.GlobalTimerPool.Put(d.timer) - case <-td.ctx.Done(): - timerutil.GlobalTimerPool.Put(d.timer) - return - } - case <-td.ctx.Done(): - return - } - } -} - -func (td *tsoDispatcher) scheduleUpdateConnectionCtxs() { - select { - case td.updateConnectionCtxsCh <- struct{}{}: - default: - } -} - func (td *tsoDispatcher) revokePendingRequests(err error) { for range len(td.tsoRequestCh) { req := <-td.tsoRequestCh @@ -196,9 +138,9 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { var ( ctx = td.ctx provider = td.provider - svcDiscovery = provider.getServiceDiscovery() option = provider.getOption() - connectionCtxs = td.connectionCtxs + svcDiscovery = provider.getServiceDiscovery() + conCtxMgr = provider.getConnectionCtxMgr() tsoBatchController *batch.Controller[*Request] ) @@ -207,10 +149,7 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { defer func() { log.Info("[tso] exit tso dispatcher") // Cancel all connections. - connectionCtxs.Range(func(_, cc any) bool { - cc.(*tsoConnectionContext).cancel() - return true - }) + conCtxMgr.ReleaseAll() if tsoBatchController != nil && tsoBatchController.GetCollectedRequestCount() != 0 { // If you encounter this failure, please check the stack in the logs to see if it's a panic. log.Fatal("batched tso requests not cleared when exiting the tso dispatcher loop", zap.Any("panic", recover())) @@ -219,8 +158,6 @@ func (td *tsoDispatcher) handleDispatcher(wg *sync.WaitGroup) { td.revokePendingRequests(tsoErr) wg.Done() }() - // Daemon goroutine to update the connectionCtxs periodically and handle the `connectionCtxs` update event. - go td.connectionCtxsUpdater() var ( err error @@ -291,14 +228,14 @@ tsoBatchLoop: // Choose a stream to send the TSO gRPC request. streamChoosingLoop: for { - connectionCtx := chooseStream(connectionCtxs) + connectionCtx := conCtxMgr.GetConnectionCtx() if connectionCtx != nil { - streamCtx, cancel, streamURL, stream = connectionCtx.ctx, connectionCtx.cancel, connectionCtx.streamURL, connectionCtx.stream + streamCtx, cancel, streamURL, stream = connectionCtx.Ctx, connectionCtx.Cancel, connectionCtx.StreamURL, connectionCtx.Stream } // Check stream and retry if necessary. if stream == nil { log.Info("[tso] tso stream is not ready") - if provider.updateConnectionCtxs(ctx, connectionCtxs) { + if provider.updateConnectionCtxs(ctx) { continue streamChoosingLoop } timer := time.NewTimer(constants.RetryInterval) @@ -325,8 +262,7 @@ tsoBatchLoop: case <-streamCtx.Done(): log.Info("[tso] tso stream is canceled", zap.String("stream-url", streamURL)) // Set `stream` to nil and remove this stream from the `connectionCtxs` due to being canceled. - connectionCtxs.Delete(streamURL) - cancel() + conCtxMgr.Release(streamURL) stream = nil continue default: @@ -334,7 +270,7 @@ tsoBatchLoop: // Check if any error has occurred on this stream when receiving asynchronously. if err = stream.GetRecvError(); err != nil { - exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + exit := !td.handleProcessRequestError(ctx, bo, conCtxMgr, streamURL, err) stream = nil if exit { td.cancelCollectedRequests(tsoBatchController, invalidStreamID, errors.WithStack(ctx.Err())) @@ -396,14 +332,11 @@ tsoBatchLoop: } } - done := make(chan struct{}) - dl := newTSDeadline(option.Timeout, done, cancel) - select { - case <-ctx.Done(): + done := td.deadlineWatcher.Start(ctx, option.Timeout, cancel) + if done == nil { // Finish the collected requests if the context is canceled. td.cancelCollectedRequests(tsoBatchController, invalidStreamID, errors.WithStack(ctx.Err())) return - case td.tsDeadlineCh <- dl: } // processRequests guarantees that the collected requests could be finished properly. err = td.processRequests(stream, tsoBatchController, done) @@ -419,7 +352,7 @@ tsoBatchLoop: // reused in the next loop safely. tsoBatchController = nil } else { - exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err) + exit := !td.handleProcessRequestError(ctx, bo, conCtxMgr, streamURL, err) stream = nil if exit { return @@ -430,113 +363,44 @@ tsoBatchLoop: // handleProcessRequestError handles errors occurs when trying to process a TSO RPC request for the dispatcher loop. // Returns true if the dispatcher loop is ok to continue. Otherwise, the dispatcher loop should be exited. -func (td *tsoDispatcher) handleProcessRequestError(ctx context.Context, bo *retry.Backoffer, streamURL string, streamCancelFunc context.CancelFunc, err error) bool { +func (td *tsoDispatcher) handleProcessRequestError( + ctx context.Context, + bo *retry.Backoffer, + conCtxMgr *cctx.Manager[*tsoStream], + streamURL string, + err error, +) bool { + log.Error("[tso] getTS error after processing requests", + zap.String("stream-url", streamURL), + zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) + select { case <-ctx.Done(): return false default: } + // Release this stream from the manager due to error. + conCtxMgr.Release(streamURL) + // Update the member list to ensure the latest topology is used before the next batch. svcDiscovery := td.provider.getServiceDiscovery() - - svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error after processing requests", - zap.String("stream-url", streamURL), - zap.Error(errs.ErrClientGetTSO.FastGenByArgs(err.Error()))) - // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. - td.connectionCtxs.Delete(streamURL) - streamCancelFunc() - // Because ScheduleCheckMemberChanged is asynchronous, if the leader changes, we better call `updateMember` ASAP. if errs.IsLeaderChange(err) { + // If the leader changed, we better call `CheckMemberChanged` blockingly to + // ensure the next round of TSO requests can be sent to the new leader. if err := bo.Exec(ctx, svcDiscovery.CheckMemberChanged); err != nil { - select { - case <-ctx.Done(): - return false - default: - } + log.Error("[tso] check member changed error after the leader changed", zap.Error(err)) } - // Because the TSO Follower Proxy could be configured online, - // If we change it from on -> off, background updateConnectionCtxs - // will cancel the current stream, then the EOF error caused by cancel() - // should not trigger the updateConnectionCtxs here. - // So we should only call it when the leader changes. - td.provider.updateConnectionCtxs(ctx, td.connectionCtxs) - } - - return true -} - -// updateConnectionCtxs updates the `connectionCtxs` regularly. -func (td *tsoDispatcher) connectionCtxsUpdater() { - var ( - ctx = td.ctx - connectionCtxs = td.connectionCtxs - provider = td.provider - option = td.provider.getOption() - updateTicker = &time.Ticker{} - ) - - log.Info("[tso] start tso connection contexts updater") - setNewUpdateTicker := func(interval time.Duration) { - if updateTicker.C != nil { - updateTicker.Stop() - } - if interval == 0 { - updateTicker = &time.Ticker{} - } else { - updateTicker = time.NewTicker(interval) - } - } - // If the TSO Follower Proxy is enabled, set the update interval to the member update interval. - if option.GetEnableTSOFollowerProxy() { - setNewUpdateTicker(sd.MemberUpdateInterval) + } else { + // For other errors, we can just schedule a member change check asynchronously. + svcDiscovery.ScheduleCheckMemberChanged() } - // Set to nil before returning to ensure that the existing ticker can be GC. - defer setNewUpdateTicker(0) - for { - provider.updateConnectionCtxs(ctx, connectionCtxs) - select { - case <-ctx.Done(): - log.Info("[tso] exit tso connection contexts updater") - return - case <-option.EnableTSOFollowerProxyCh: - enableTSOFollowerProxy := option.GetEnableTSOFollowerProxy() - log.Info("[tso] tso follower proxy status changed", - zap.Bool("enable", enableTSOFollowerProxy)) - if enableTSOFollowerProxy && updateTicker.C == nil { - // Because the TSO Follower Proxy is enabled, - // the periodic check needs to be performed. - setNewUpdateTicker(sd.MemberUpdateInterval) - failpoint.Inject("speedUpTsoDispatcherUpdateInterval", func() { - setNewUpdateTicker(10 * time.Millisecond) - }) - } else if !enableTSOFollowerProxy && updateTicker.C != nil { - // Because the TSO Follower Proxy is disabled, - // the periodic check needs to be turned off. - setNewUpdateTicker(0) - } - case <-updateTicker.C: - // Triggered periodically when the TSO Follower Proxy is enabled. - case <-td.updateConnectionCtxsCh: - // Triggered by the leader/follower change. - } - } -} - -// chooseStream uses the reservoir sampling algorithm to randomly choose a connection. -// connectionCtxs will only have only one stream to choose when the TSO Follower Proxy is off. -func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext) { - idx := 0 - connectionCtxs.Range(func(_, cc any) bool { - j := rand.Intn(idx + 1) - if j < 1 { - connectionCtx = cc.(*tsoConnectionContext) - } - idx++ + select { + case <-ctx.Done(): + return false + default: return true - }) - return connectionCtx + } } // processRequests sends the RPC request for the batch. It's guaranteed that after calling this function, requests diff --git a/client/clients/tso/dispatcher_test.go b/client/clients/tso/dispatcher_test.go index cefc53f3944..7e5554c7c7b 100644 --- a/client/clients/tso/dispatcher_test.go +++ b/client/clients/tso/dispatcher_test.go @@ -30,19 +30,21 @@ import ( "github.com/pingcap/log" "github.com/tikv/pd/client/opt" + cctx "github.com/tikv/pd/client/pkg/connectionctx" sd "github.com/tikv/pd/client/servicediscovery" ) type mockTSOServiceProvider struct { option *opt.Option createStream func(ctx context.Context) *tsoStream - updateConnMu sync.Mutex + conCtxMgr *cctx.Manager[*tsoStream] } func newMockTSOServiceProvider(option *opt.Option, createStream func(ctx context.Context) *tsoStream) *mockTSOServiceProvider { return &mockTSOServiceProvider{ option: option, createStream: createStream, + conCtxMgr: cctx.NewManager[*tsoStream](), } } @@ -54,24 +56,21 @@ func (*mockTSOServiceProvider) getServiceDiscovery() sd.ServiceDiscovery { return sd.NewMockPDServiceDiscovery([]string{mockStreamURL}, nil) } -func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context, connectionCtxs *sync.Map) bool { - // Avoid concurrent updating in the background updating goroutine and active updating in the dispatcher loop when - // stream is missing. - m.updateConnMu.Lock() - defer m.updateConnMu.Unlock() +func (m *mockTSOServiceProvider) getConnectionCtxMgr() *cctx.Manager[*tsoStream] { + return m.conCtxMgr +} - _, ok := connectionCtxs.Load(mockStreamURL) - if ok { +func (m *mockTSOServiceProvider) updateConnectionCtxs(ctx context.Context) bool { + if m.conCtxMgr.Exist(mockStreamURL) { return true } - ctx, cancel := context.WithCancel(ctx) var stream *tsoStream if m.createStream == nil { stream = newTSOStream(ctx, mockStreamURL, newMockTSOStreamImpl(ctx, resultModeGenerated)) } else { stream = m.createStream(ctx) } - connectionCtxs.LoadOrStore(mockStreamURL, &tsoConnectionContext{ctx, cancel, mockStreamURL, stream}) + m.conCtxMgr.Store(ctx, mockStreamURL, stream) return true } diff --git a/client/errs/errno.go b/client/errs/errno.go index 25665f01017..99a426d0776 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -56,7 +56,6 @@ var ( ErrClientGetMetaStorageClient = errors.Normalize("failed to get meta storage client", errors.RFCCodeText("PD:client:ErrClientGetMetaStorageClient")) ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream")) ErrClientTSOStreamClosed = errors.Normalize("encountered TSO stream being closed unexpectedly", errors.RFCCodeText("PD:client:ErrClientTSOStreamClosed")) - ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout")) ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetMinTSO = errors.Normalize("get min TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetMinTSO")) ErrClientGetLeader = errors.Normalize("get leader failed, %v", errors.RFCCodeText("PD:client:ErrClientGetLeader")) diff --git a/client/pkg/connectionctx/manager.go b/client/pkg/connectionctx/manager.go new file mode 100644 index 00000000000..04c1eb13d3a --- /dev/null +++ b/client/pkg/connectionctx/manager.go @@ -0,0 +1,143 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionctx + +import ( + "context" + "sync" + + "golang.org/x/exp/rand" +) + +type connectionCtx[T any] struct { + Ctx context.Context + Cancel context.CancelFunc + // Current URL of the stream connection. + StreamURL string + // Current stream to send the gRPC requests. + Stream T +} + +// Manager is used to manage the connection contexts. +type Manager[T any] struct { + sync.RWMutex + connectionCtxs map[string]*connectionCtx[T] +} + +// NewManager is used to create a new connection context manager. +func NewManager[T any]() *Manager[T] { + return &Manager[T]{ + connectionCtxs: make(map[string]*connectionCtx[T], 3), + } +} + +// Exist is used to check if the connection context exists by the given URL. +func (c *Manager[T]) Exist(url string) bool { + c.RLock() + defer c.RUnlock() + _, ok := c.connectionCtxs[url] + return ok +} + +// Store is used to store the connection context, `overwrite` is used to force the store operation +// no matter whether the connection context exists before, which is false by default. +func (c *Manager[T]) Store(ctx context.Context, url string, stream T, overwrite ...bool) { + c.Lock() + defer c.Unlock() + overwriteFlag := false + if len(overwrite) > 0 { + overwriteFlag = overwrite[0] + } + _, ok := c.connectionCtxs[url] + if !overwriteFlag && ok { + return + } + c.storeLocked(ctx, url, stream) +} + +func (c *Manager[T]) storeLocked(ctx context.Context, url string, stream T) { + c.releaseLocked(url) + cctx, cancel := context.WithCancel(ctx) + c.connectionCtxs[url] = &connectionCtx[T]{cctx, cancel, url, stream} +} + +// CleanAllAndStore is used to store the connection context exclusively. It will release +// all other connection contexts. `stream` is optional, if it is not provided, all +// connection contexts other than the given `url` will be released. +func (c *Manager[T]) CleanAllAndStore(ctx context.Context, url string, stream ...T) { + c.Lock() + defer c.Unlock() + // Remove all other `connectionCtx`s. + c.gcLocked(func(curURL string) bool { + return curURL != url + }) + if len(stream) == 0 { + return + } + c.storeLocked(ctx, url, stream[0]) +} + +// GC is used to release all connection contexts that match the given condition. +func (c *Manager[T]) GC(condition func(url string) bool) { + c.Lock() + defer c.Unlock() + c.gcLocked(condition) +} + +func (c *Manager[T]) gcLocked(condition func(url string) bool) { + for url := range c.connectionCtxs { + if condition(url) { + c.releaseLocked(url) + } + } +} + +// ReleaseAll is used to release all connection contexts. +func (c *Manager[T]) ReleaseAll() { + c.GC(func(string) bool { return true }) +} + +// Release is used to delete a connection context from the connection context map and release the resources. +func (c *Manager[T]) Release(url string) { + c.Lock() + defer c.Unlock() + c.releaseLocked(url) +} + +func (c *Manager[T]) releaseLocked(url string) { + cc, ok := c.connectionCtxs[url] + if !ok { + return + } + cc.Cancel() + delete(c.connectionCtxs, url) +} + +// GetConnectionCtx is used to get a connection context from the connection context map. +// It uses the reservoir sampling algorithm to randomly pick one connection context. +func (c *Manager[T]) GetConnectionCtx() *connectionCtx[T] { + c.RLock() + defer c.RUnlock() + idx := 0 + var connectionCtx *connectionCtx[T] + for _, cc := range c.connectionCtxs { + j := rand.Intn(idx + 1) + if j < 1 { + connectionCtx = cc + } + idx++ + } + return connectionCtx +} diff --git a/client/pkg/connectionctx/manager_test.go b/client/pkg/connectionctx/manager_test.go new file mode 100644 index 00000000000..42504673b95 --- /dev/null +++ b/client/pkg/connectionctx/manager_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionctx + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestManager(t *testing.T) { + re := require.New(t) + ctx := context.Background() + manager := NewManager[int]() + + re.False(manager.Exist("test-url")) + manager.Store(ctx, "test-url", 1) + re.True(manager.Exist("test-url")) + + cctx := manager.GetConnectionCtx() + re.Equal("test-url", cctx.StreamURL) + re.Equal(1, cctx.Stream) + + manager.Store(ctx, "test-url", 2) + cctx = manager.GetConnectionCtx() + re.Equal("test-url", cctx.StreamURL) + re.Equal(1, cctx.Stream) + + manager.Store(ctx, "test-url", 2, true) + cctx = manager.GetConnectionCtx() + re.Equal("test-url", cctx.StreamURL) + re.Equal(2, cctx.Stream) + + manager.Store(ctx, "test-another-url", 3) + pickedCount := make(map[string]int) + for range 1000 { + cctx = manager.GetConnectionCtx() + pickedCount[cctx.StreamURL]++ + } + re.NotEmpty(pickedCount["test-url"]) + re.NotEmpty(pickedCount["test-another-url"]) + re.Equal(1000, pickedCount["test-url"]+pickedCount["test-another-url"]) + + manager.GC(func(url string) bool { + return url == "test-url" + }) + re.False(manager.Exist("test-url")) + re.True(manager.Exist("test-another-url")) + + manager.CleanAllAndStore(ctx, "test-url", 1) + re.True(manager.Exist("test-url")) + re.False(manager.Exist("test-another-url")) + + manager.Store(ctx, "test-another-url", 3) + manager.CleanAllAndStore(ctx, "test-unique-url", 4) + re.True(manager.Exist("test-unique-url")) + re.False(manager.Exist("test-url")) + re.False(manager.Exist("test-another-url")) + + manager.Release("test-unique-url") + re.False(manager.Exist("test-unique-url")) + + for i := range 1000 { + manager.Store(ctx, fmt.Sprintf("test-url-%d", i), i) + } + re.Len(manager.connectionCtxs, 1000) + manager.ReleaseAll() + re.Empty(manager.connectionCtxs) +} diff --git a/client/pkg/deadline/watcher.go b/client/pkg/deadline/watcher.go new file mode 100644 index 00000000000..b40857edbfd --- /dev/null +++ b/client/pkg/deadline/watcher.go @@ -0,0 +1,111 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deadline + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/pingcap/log" + + "github.com/tikv/pd/client/pkg/utils/timerutil" +) + +// The `cancel` function will be invoked once the specified `timeout` elapses without receiving a `done` signal. +type deadline struct { + timer *time.Timer + done chan struct{} + cancel context.CancelFunc +} + +// Watcher is used to watch and manage the deadlines. +type Watcher struct { + ctx context.Context + source string + Ch chan *deadline +} + +// NewWatcher is used to create a new deadline watcher. +func NewWatcher(ctx context.Context, capacity int, source string) *Watcher { + watcher := &Watcher{ + ctx: ctx, + source: source, + Ch: make(chan *deadline, capacity), + } + go watcher.Watch() + return watcher +} + +// Watch is used to watch the deadlines and invoke the `cancel` function when the deadline is reached. +// The `err` will be returned if the deadline is reached. +func (w *Watcher) Watch() { + log.Info("[pd] start the deadline watcher", zap.String("source", w.source)) + defer log.Info("[pd] exit the deadline watcher", zap.String("source", w.source)) + for { + select { + case d := <-w.Ch: + select { + case <-d.timer.C: + log.Error("[pd] the deadline is reached", zap.String("source", w.source)) + d.cancel() + timerutil.GlobalTimerPool.Put(d.timer) + case <-d.done: + timerutil.GlobalTimerPool.Put(d.timer) + case <-w.ctx.Done(): + timerutil.GlobalTimerPool.Put(d.timer) + return + } + case <-w.ctx.Done(): + return + } + } +} + +// Start is used to start a deadline. It returns a channel which will be closed when the deadline is reached. +// Returns nil if the deadline is not started. +func (w *Watcher) Start( + ctx context.Context, + timeout time.Duration, + cancel context.CancelFunc, +) chan struct{} { + // Check if the watcher is already canceled. + select { + case <-w.ctx.Done(): + return nil + case <-ctx.Done(): + return nil + default: + } + // Initialize the deadline. + timer := timerutil.GlobalTimerPool.Get(timeout) + d := &deadline{ + timer: timer, + done: make(chan struct{}), + cancel: cancel, + } + // Send the deadline to the watcher. + select { + case <-w.ctx.Done(): + timerutil.GlobalTimerPool.Put(timer) + return nil + case <-ctx.Done(): + timerutil.GlobalTimerPool.Put(timer) + return nil + case w.Ch <- d: + return d.done + } +} diff --git a/client/pkg/deadline/watcher_test.go b/client/pkg/deadline/watcher_test.go new file mode 100644 index 00000000000..b93987b8874 --- /dev/null +++ b/client/pkg/deadline/watcher_test.go @@ -0,0 +1,58 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deadline + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWatcher(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcher := NewWatcher(ctx, 10, "test") + var deadlineReached atomic.Bool + done := watcher.Start(ctx, time.Millisecond, func() { + deadlineReached.Store(true) + }) + re.NotNil(done) + time.Sleep(5 * time.Millisecond) + re.True(deadlineReached.Load()) + + deadlineReached.Store(false) + done = watcher.Start(ctx, 500*time.Millisecond, func() { + deadlineReached.Store(true) + }) + re.NotNil(done) + done <- struct{}{} + time.Sleep(time.Second) + re.False(deadlineReached.Load()) + + deadCtx, deadCancel := context.WithCancel(ctx) + deadCancel() + deadlineReached.Store(false) + done = watcher.Start(deadCtx, time.Millisecond, func() { + deadlineReached.Store(true) + }) + re.Nil(done) + time.Sleep(5 * time.Millisecond) + re.False(deadlineReached.Load()) +} diff --git a/client/servicediscovery/pd_service_discovery.go b/client/servicediscovery/pd_service_discovery.go index 619d4196408..5530f3cfa9b 100644 --- a/client/servicediscovery/pd_service_discovery.go +++ b/client/servicediscovery/pd_service_discovery.go @@ -966,12 +966,9 @@ func (c *pdServiceDiscovery) updateURLs(members []*pdpb.Member) { return } c.urls.Store(urls) - // Update the connection contexts when member changes if TSO Follower Proxy is enabled. - if c.option.GetEnableTSOFollowerProxy() { - // Run callbacks to reflect the membership changes in the leader and followers. - for _, cb := range c.membersChangedCbs { - cb() - } + // Run callbacks to reflect the membership changes in the leader and followers. + for _, cb := range c.membersChangedCbs { + cb() } log.Info("[pd] update member urls", zap.Strings("old-urls", oldURLs), zap.Strings("new-urls", urls)) } diff --git a/errors.toml b/errors.toml index 2ab3b014f5a..9980a98ab14 100644 --- a/errors.toml +++ b/errors.toml @@ -131,11 +131,6 @@ error = ''' get TSO failed ''' -["PD:client:ErrClientGetTSOTimeout"] -error = ''' -get TSO timeout -''' - ["PD:cluster:ErrInvalidStoreID"] error = ''' invalid store id %d, not found diff --git a/metrics/grafana/pd.json b/metrics/grafana/pd.json index 0f4e91afd50..62b2e7234ef 100644 --- a/metrics/grafana/pd.json +++ b/metrics/grafana/pd.json @@ -8937,6 +8937,226 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "dashLength": 10, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The rate of received each kind of gRPC commands", + "editable": true, + "fill": 1, + "grid": {}, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 127 + }, + "id": 904, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": 300, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "paceLength": 10, + "pointradius": 5, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "targets": [ + { + "expr": "sum(rate(grpc_server_msg_received_total{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\"}[1m])) by (instance, grpc_method)", + "legendFormat": "{{instance}}-{{grpc_method}}", + "interval": "", + "exemplar": true, + "intervalFactor": 2, + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeRegions": [], + "title": "gRPC Received commands rate", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ops", + "label": null, + "logBase": 10, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + }, + "options": { + "alertThreshold": true + }, + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "pluginVersion": "7.5.17", + "bars": false, + "dashes": false, + "decimals": null, + "error": false, + "percentage": false, + "points": false, + "stack": false, + "steppedLine": false, + "timeFrom": null, + "timeShift": null, + "fillGradient": 0, + "hiddenSeries": false + }, + { + "aliasColors": {}, + "dashLength": 10, + "datasource": "${DS_TEST-CLUSTER}", + "description": "The error rate of handled gRPC commands.Note: It can't catch the error hide in the header, like this https://github.com/tikv/pd/blob/2d970a619a8917c35d306f401326141481c133e0/server/grpc_service.go#L2071", + "editable": true, + "fieldConfig": { + "defaults": {}, + "overrides": [] + }, + "fill": 1, + "grid": {}, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 135 + }, + "id": 905, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "hideEmpty": true, + "hideZero": true, + "max": true, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": 300, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null as zero", + "options": { + "alertThreshold": true + }, + "paceLength": 10, + "pluginVersion": "7.5.17", + "pointradius": 5, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "targets": [ + { + "expr": "sum(rate(grpc_server_handled_total{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\", grpc_type=\"unary\", grpc_code!=\"OK\"}[1m])) by (grpc_method)", + "legendFormat": "{{grpc_method}}", + "interval": "", + "exemplar": true, + "intervalFactor": 2, + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeRegions": [], + "title": "gRPC Error rate", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ops", + "label": null, + "logBase": 10, + "max": null, + "min": null, + "show": true, + "$$hashKey": "object:132" + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true, + "$$hashKey": "object:133" + } + ], + "yaxis": { + "align": false, + "alignLevel": null + }, + "bars": false, + "dashes": false, + "decimals": null, + "error": false, + "fillGradient": 0, + "hiddenSeries": false, + "percentage": false, + "points": false, + "stack": false, + "steppedLine": false, + "timeFrom": null, + "timeShift": null } ], "repeat": null, diff --git a/pkg/cgroup/cgroup_cpu_test.go b/pkg/cgroup/cgroup_cpu_test.go index 441c2192e79..6d4d8f39f49 100644 --- a/pkg/cgroup/cgroup_cpu_test.go +++ b/pkg/cgroup/cgroup_cpu_test.go @@ -17,7 +17,6 @@ package cgroup import ( - "fmt" "regexp" "runtime" "strconv" @@ -45,10 +44,10 @@ func checkKernelVersionNewerThan(re *require.Assertions, t *testing.T, major, mi t.Log("kernel release string:", releaseStr) versionInfoRE := regexp.MustCompile(`[0-9]+\.[0-9]+\.[0-9]+`) kernelVersion := versionInfoRE.FindAllString(releaseStr, 1) - re.Len(kernelVersion, 1, fmt.Sprintf("release str is %s", releaseStr)) + re.Lenf(kernelVersion, 1, "release str is %s", releaseStr) kernelVersionPartRE := regexp.MustCompile(`[0-9]+`) kernelVersionParts := kernelVersionPartRE.FindAllString(kernelVersion[0], -1) - re.Len(kernelVersionParts, 3, fmt.Sprintf("kernel version str is %s", kernelVersion[0])) + re.Lenf(kernelVersionParts, 3, "kernel version str is %s", kernelVersion[0]) t.Logf("parsed kernel version parts: major %s, minor %s, patch %s", kernelVersionParts[0], kernelVersionParts[1], kernelVersionParts[2]) mustConvInt := func(s string) int { diff --git a/pkg/core/region_test.go b/pkg/core/region_test.go index 51ba5fe96dc..473421b0e52 100644 --- a/pkg/core/region_test.go +++ b/pkg/core/region_test.go @@ -985,10 +985,10 @@ func TestUpdateRegionEquivalence(t *testing.T) { checkRegions(re, regionsNew) for _, r := range regionsOld.GetRegions() { - re.Equal(int32(2), r.GetRef(), fmt.Sprintf("inconsistent region %d", r.GetID())) + re.Equalf(int32(2), r.GetRef(), "inconsistent region %d", r.GetID()) } for _, r := range regionsNew.GetRegions() { - re.Equal(int32(2), r.GetRef(), fmt.Sprintf("inconsistent region %d", r.GetID())) + re.Equalf(int32(2), r.GetRef(), "inconsistent region %d", r.GetID()) } for i := 1; i <= storeNums; i++ { diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index ee24b4d0673..834bf4f824e 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -144,7 +144,6 @@ var ( // client errors var ( ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream")) - ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout")) ErrClientGetTSO = errors.Normalize("get TSO failed", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetLeader = errors.Normalize("get leader failed, %v", errors.RFCCodeText("PD:client:ErrClientGetLeader")) ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember")) diff --git a/pkg/mcs/resourcemanager/server/metrics_test.go b/pkg/mcs/resourcemanager/server/metrics_test.go index 4c3ec7ce5ef..97b21bf5ce3 100644 --- a/pkg/mcs/resourcemanager/server/metrics_test.go +++ b/pkg/mcs/resourcemanager/server/metrics_test.go @@ -15,7 +15,6 @@ package server import ( - "fmt" "testing" "github.com/stretchr/testify/require" @@ -43,8 +42,8 @@ func TestMaxPerSecCostTracker(t *testing.T) { // Check the max values at the end of each flushPeriod if (i+1)%20 == 0 { period := i / 20 - re.Equal(tracker.maxPerSecRRU, expectedMaxRU[period], fmt.Sprintf("maxPerSecRRU in period %d is incorrect", period+1)) - re.Equal(tracker.maxPerSecWRU, expectedMaxRU[period], fmt.Sprintf("maxPerSecWRU in period %d is incorrect", period+1)) + re.Equalf(tracker.maxPerSecRRU, expectedMaxRU[period], "maxPerSecRRU in period %d is incorrect", period+1) + re.Equalf(tracker.maxPerSecWRU, expectedMaxRU[period], "maxPerSecWRU in period %d is incorrect", period+1) re.Equal(tracker.rruSum, expectedSum[period]) re.Equal(tracker.rruSum, expectedSum[period]) } diff --git a/pkg/mcs/resourcemanager/server/token_buckets_test.go b/pkg/mcs/resourcemanager/server/token_buckets_test.go index b56ccb6ab96..676b1127f35 100644 --- a/pkg/mcs/resourcemanager/server/token_buckets_test.go +++ b/pkg/mcs/resourcemanager/server/token_buckets_test.go @@ -15,7 +15,6 @@ package server import ( - "fmt" "math" "testing" "time" @@ -182,9 +181,9 @@ func TestGroupTokenBucketRequestLoop(t *testing.T) { currentTime := initialTime for i, tc := range testCases { tb, trickle := gtb.request(currentTime, tc.requestTokens, uint64(targetPeriod)/uint64(time.Millisecond), clientUniqueID) - re.Equal(tc.globalBucketTokensAfterAssign, gtb.GetTokenBucket().Tokens, fmt.Sprintf("Test case %d failed: expected bucket tokens %f, got %f", i, tc.globalBucketTokensAfterAssign, gtb.GetTokenBucket().Tokens)) - re.LessOrEqual(math.Abs(tb.Tokens-tc.assignedTokens), 1e-7, fmt.Sprintf("Test case %d failed: expected tokens %f, got %f", i, tc.assignedTokens, tb.Tokens)) - re.Equal(tc.expectedTrickleMs, trickle, fmt.Sprintf("Test case %d failed: expected trickle %d, got %d", i, tc.expectedTrickleMs, trickle)) + re.Equalf(tc.globalBucketTokensAfterAssign, gtb.GetTokenBucket().Tokens, "Test case %d failed: expected bucket tokens %f, got %f", i, tc.globalBucketTokensAfterAssign, gtb.GetTokenBucket().Tokens) + re.LessOrEqualf(math.Abs(tb.Tokens-tc.assignedTokens), 1e-7, "Test case %d failed: expected tokens %f, got %f", i, tc.assignedTokens, tb.Tokens) + re.Equalf(tc.expectedTrickleMs, trickle, "Test case %d failed: expected trickle %d, got %d", i, tc.expectedTrickleMs, trickle) currentTime = currentTime.Add(timeIncrement) } } diff --git a/pkg/tso/keyspace_group_manager_test.go b/pkg/tso/keyspace_group_manager_test.go index b4393a23471..c54bbcc1b33 100644 --- a/pkg/tso/keyspace_group_manager_test.go +++ b/pkg/tso/keyspace_group_manager_test.go @@ -886,7 +886,7 @@ func collectAssignedKeyspaceGroupIDs(re *require.Assertions, kgm *KeyspaceGroupM for i := range kgm.kgs { kg := kgm.kgs[i] if kg == nil { - re.Nil(kgm.ams[i], fmt.Sprintf("ksg is nil but am is not nil for id %d", i)) + re.Nilf(kgm.ams[i], "ksg is nil but am is not nil for id %d", i) } else { am := kgm.ams[i] if am != nil { @@ -976,8 +976,8 @@ func (suite *keyspaceGroupManagerTestSuite) TestUpdateKeyspaceGroupMembership() func verifyLocalKeyspaceLookupTable( re *require.Assertions, keyspaceLookupTable map[uint32]struct{}, newKeyspaces []uint32, ) { - re.Equal(len(newKeyspaces), len(keyspaceLookupTable), - fmt.Sprintf("%v %v", newKeyspaces, keyspaceLookupTable)) + re.Equalf(len(newKeyspaces), len(keyspaceLookupTable), + "%v %v", newKeyspaces, keyspaceLookupTable) for _, keyspace := range newKeyspaces { _, ok := keyspaceLookupTable[keyspace] re.True(ok) diff --git a/pkg/window/policy_test.go b/pkg/window/policy_test.go index 936360ccb2b..b5a04c03e4b 100644 --- a/pkg/window/policy_test.go +++ b/pkg/window/policy_test.go @@ -18,7 +18,6 @@ package window import ( - "fmt" "math" "testing" "time" @@ -79,11 +78,11 @@ func TestRollingPolicy_Add(t *testing.T) { asExpected = false } if asExpected { - re.Less(math.Abs(point-policy.window.buckets[offset].Points[0]), 1e-6, - fmt.Sprintf("error, time since last append: %vms, last offset: %v", totalTS, lastOffset)) + re.Lessf(math.Abs(point-policy.window.buckets[offset].Points[0]), 1e-6, + "error, time since last append: %vms, last offset: %v", totalTS, lastOffset) } - re.Less(math.Abs(points[i]-policy.window.buckets[offset].Points[0]), 1e-6, - fmt.Sprintf("error, time since last append: %vms, last offset: %v", totalTS, lastOffset)) + re.Lessf(math.Abs(points[i]-policy.window.buckets[offset].Points[0]), 1e-6, + "error, time since last append: %vms, last offset: %v", totalTS, lastOffset) lastOffset = offset } }) diff --git a/tests/integrations/mcs/scheduling/server_test.go b/tests/integrations/mcs/scheduling/server_test.go index ea1e9df0b50..3401fb880cb 100644 --- a/tests/integrations/mcs/scheduling/server_test.go +++ b/tests/integrations/mcs/scheduling/server_test.go @@ -106,7 +106,7 @@ func (suite *serverTestSuite) TestAllocIDAfterLeaderChange() { re := suite.Require() re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/mcs/scheduling/server/fastUpdateMember", `return(true)`)) pd2, err := suite.cluster.Join(suite.ctx) - re.NoError(err) + re.NoError(err, "error: %v", err) err = pd2.Run() re.NotEmpty(suite.cluster.WaitLeader()) re.NoError(err) @@ -261,6 +261,8 @@ func (suite *serverTestSuite) TestDisableSchedulingServiceFallback() { // API server will execute scheduling jobs since there is no scheduling server. testutil.Eventually(re, func() bool { + re.NotNil(suite.pdLeader.GetServer()) + re.NotNil(suite.pdLeader.GetServer().GetRaftCluster()) return suite.pdLeader.GetServer().GetRaftCluster().IsSchedulingControllerRunning() }) leaderServer := suite.pdLeader.GetServer() diff --git a/tests/integrations/realcluster/cluster.go b/tests/integrations/realcluster/cluster.go index cc4f6b54713..fc5d1bc4441 100644 --- a/tests/integrations/realcluster/cluster.go +++ b/tests/integrations/realcluster/cluster.go @@ -185,6 +185,5 @@ func waitTiupReady(t *testing.T, tag string) { zap.String("tag", tag), zap.Error(err)) time.Sleep(time.Duration(interval) * time.Second) } - // this check can trigger the cleanup function - require.NotZero(t, 1, "TiUP is not ready", "tag: %s", tag) + require.FailNowf(t, "TiUP is not ready after retry: %s", tag) } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 00c43d11309..faa22ce08f4 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -173,8 +173,8 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { re.Equal(http.StatusOK, resp.StatusCode) re.Equal("Profile", resp.Header.Get("service-label")) - re.Equal("{\"seconds\":[\"1\"]}", resp.Header.Get("url-param")) - re.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) + re.JSONEq("{\"seconds\":[\"1\"]}", resp.Header.Get("url-param")) + re.JSONEq("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) re.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method")) re.Equal("anonymous", resp.Header.Get("caller-id")) re.Equal("127.0.0.1", resp.Header.Get("ip")) diff --git a/tools/pd-ctl/tests/scheduler/scheduler_test.go b/tools/pd-ctl/tests/scheduler/scheduler_test.go index 787bdaa4521..f3a81845921 100644 --- a/tools/pd-ctl/tests/scheduler/scheduler_test.go +++ b/tools/pd-ctl/tests/scheduler/scheduler_test.go @@ -186,18 +186,18 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *pdTests.TestCluster) { case "grant-leader-scheduler": return "paused", !storeInfo.AllowLeaderTransferOut() default: - re.Fail(fmt.Sprintf("unknown scheduler %s", schedulerName)) + re.Failf("unknown scheduler %s", schedulerName) return "", false } }() if slice.AnyOf(changedStores, func(i int) bool { return store.GetId() == changedStores[i] }) { - re.True(isStorePaused, - fmt.Sprintf("store %d should be %s with %s", store.GetId(), status, schedulerName)) + re.Truef(isStorePaused, + "store %d should be %s with %s", store.GetId(), status, schedulerName) } else { - re.False(isStorePaused, - fmt.Sprintf("store %d should not be %s with %s", store.GetId(), status, schedulerName)) + re.Falsef(isStorePaused, + "store %d should not be %s with %s", store.GetId(), status, schedulerName) } if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { switch schedulerName { @@ -206,7 +206,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *pdTests.TestCluster) { case "grant-leader-scheduler": re.Equal(isStorePaused, !sche.GetCluster().GetStore(store.GetId()).AllowLeaderTransferOut()) default: - re.Fail(fmt.Sprintf("unknown scheduler %s", schedulerName)) + re.Failf("unknown scheduler %s", schedulerName) } } }