diff --git a/client/client.go b/client/client.go index 8b21b17169e..b2543a4e49d 100644 --- a/client/client.go +++ b/client/client.go @@ -570,8 +570,18 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } +// EnableRouterClient enables the router client. +// This is only for test currently. +func (c *client) EnableRouterClient() { + c.inner.enableRouterClient.Store(true) +} + +func (c *client) isRouterClientEnabled() bool { + return c.inner.enableRouterClient.Load() +} + // GetRegionFromMember implements the RPCClient interface. -func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) { +func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, opts ...opt.GetRegionOption) (*router.Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span = span.Tracer().StartSpan("pdclient.GetRegionFromMember", opentracing.ChildOf(span.Context())) defer span.Finish() @@ -620,6 +630,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -660,6 +674,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetPrevRegion(ctx, key, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) @@ -700,6 +718,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout) defer cancel() + if c.isRouterClientEnabled() { + return c.inner.routerClient.GetRegionByID(ctx, regionID, opts...) + } + options := &opt.GetRegionOp{} for _, opt := range opts { opt(options) diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 8d0cfd64d2d..8bd44b8b6a3 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -227,10 +227,7 @@ func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request } else if req.id != 0 { id = req.id } - region, ok := resp.RegionsById[id] - if !ok { - err = errs.ErrClientRegionNotFound.FastGenByArgs(id) - } else { + if region, ok := resp.RegionsById[id]; ok { req.region = ConvertToRegion(region) } req.tryDone(err) diff --git a/client/clients/router/request.go b/client/clients/router/request.go index 2e1c2e97aa5..cc1ada0a729 100644 --- a/client/clients/router/request.go +++ b/client/clients/router/request.go @@ -80,12 +80,6 @@ func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOp return req.wait() } -// GetRegionFromMember implements the Client interface. -func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) { - // Before we support the follower stream connection, this method is equivalent to `GetRegion`. - return c.GetRegion(ctx, key, opts...) -} - // GetPrevRegion implements the Client interface. func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { req := c.newRequest(ctx) diff --git a/client/errs/errno.go b/client/errs/errno.go index 8f81d2d6777..99a426d0776 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,7 +70,6 @@ var ( ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) - ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound")) ) // grpcutil errors diff --git a/client/inner_client.go b/client/inner_client.go index 269c2330f8e..464fd413e25 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "sync" + "sync/atomic" "time" "go.uber.org/zap" @@ -46,7 +47,12 @@ type innerClient struct { serviceDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher - routerClient *router.Cli + // The router client is used to get the region info via the streaming gRPC, + // this flag is used to control whether to enable it, currently only used + // in the test. + enableRouterClient atomic.Bool + routerClient *router.Cli + // For service mode switching. serviceModeKeeper diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 91a6d44943e..b8c4c6b3d62 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -1025,12 +1025,18 @@ type clientTestSuite struct { grpcPDClient pdpb.PDClient regionHeartbeat pdpb.PD_RegionHeartbeatClient reportBucket pdpb.PD_ReportBucketsClient + + enableRouterClient bool } func TestClientTestSuite(t *testing.T) { suite.Run(t, new(clientTestSuite)) } +func TestClientTestSuiteWithRouterClient(t *testing.T) { + suite.Run(t, &clientTestSuite{enableRouterClient: true}) +} + func (suite *clientTestSuite) SetupSuite() { var err error re := suite.Require() @@ -1044,6 +1050,9 @@ func (suite *clientTestSuite) SetupSuite() { suite.ctx, suite.clean = context.WithCancel(context.Background()) suite.client = setupCli(suite.ctx, re, suite.srv.GetEndpoints()) + if suite.enableRouterClient { + suite.client.(interface{ EnableRouterClient() }).EnableRouterClient() + } suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx) re.NoError(err)