Skip to content

Commit

Permalink
Setup the test suite
Browse files Browse the repository at this point in the history
Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato committed Jan 22, 2025
1 parent 48ff881 commit 4b253eb
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 13 deletions.
24 changes: 23 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,18 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e
return minTS.Physical, minTS.Logical, nil
}

// EnableRouterClient enables the router client.
// This is only for test currently.
func (c *client) EnableRouterClient() {
c.inner.enableRouterClient.Store(true)
}

func (c *client) isRouterClientEnabled() bool {
return c.inner.enableRouterClient.Load()
}

// GetRegionFromMember implements the RPCClient interface.
func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) {
func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, opts ...opt.GetRegionOption) (*router.Region, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span = span.Tracer().StartSpan("pdclient.GetRegionFromMember", opentracing.ChildOf(span.Context()))
defer span.Finish()
Expand Down Expand Up @@ -620,6 +630,10 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio
ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
defer cancel()

if c.isRouterClientEnabled() {
return c.inner.routerClient.GetRegion(ctx, key, opts...)
}

options := &opt.GetRegionOp{}
for _, opt := range opts {
opt(options)
Expand Down Expand Up @@ -660,6 +674,10 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR
ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
defer cancel()

if c.isRouterClientEnabled() {
return c.inner.routerClient.GetPrevRegion(ctx, key, opts...)
}

options := &opt.GetRegionOp{}
for _, opt := range opts {
opt(options)
Expand Down Expand Up @@ -700,6 +718,10 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt
ctx, cancel := context.WithTimeout(ctx, c.inner.option.Timeout)
defer cancel()

if c.isRouterClientEnabled() {
return c.inner.routerClient.GetRegionByID(ctx, regionID, opts...)
}

options := &opt.GetRegionOp{}
for _, opt := range opts {
opt(options)
Expand Down
5 changes: 1 addition & 4 deletions client/clients/router/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,7 @@ func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request
} else if req.id != 0 {
id = req.id
}
region, ok := resp.RegionsById[id]
if !ok {
err = errs.ErrClientRegionNotFound.FastGenByArgs(id)
} else {
if region, ok := resp.RegionsById[id]; ok {
req.region = ConvertToRegion(region)
}
req.tryDone(err)
Expand Down
6 changes: 0 additions & 6 deletions client/clients/router/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOp
return req.wait()
}

// GetRegionFromMember implements the Client interface.
func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) {
// Before we support the follower stream connection, this method is equivalent to `GetRegion`.
return c.GetRegion(ctx, key, opts...)
}

// GetPrevRegion implements the Client interface.
func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) {
req := c.newRequest(ctx)
Expand Down
1 change: 0 additions & 1 deletion client/errs/errno.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ var (
ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID"))
ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream"))
ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen"))
ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound"))
)

// grpcutil errors
Expand Down
8 changes: 7 additions & 1 deletion client/inner_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/tls"
"sync"
"sync/atomic"
"time"

"go.uber.org/zap"
Expand Down Expand Up @@ -46,7 +47,12 @@ type innerClient struct {
serviceDiscovery sd.ServiceDiscovery
tokenDispatcher *tokenDispatcher

routerClient *router.Cli
// The router client is used to get the region info via the streaming gRPC,
// this flag is used to control whether to enable it, currently only used
// in the test.
enableRouterClient atomic.Bool
routerClient *router.Cli

// For service mode switching.
serviceModeKeeper

Expand Down
9 changes: 9 additions & 0 deletions tests/integrations/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1025,12 +1025,18 @@ type clientTestSuite struct {
grpcPDClient pdpb.PDClient
regionHeartbeat pdpb.PD_RegionHeartbeatClient
reportBucket pdpb.PD_ReportBucketsClient

enableRouterClient bool
}

func TestClientTestSuite(t *testing.T) {
suite.Run(t, new(clientTestSuite))
}

func TestClientTestSuiteWithRouterClient(t *testing.T) {
suite.Run(t, &clientTestSuite{enableRouterClient: true})
}

func (suite *clientTestSuite) SetupSuite() {
var err error
re := suite.Require()
Expand All @@ -1044,6 +1050,9 @@ func (suite *clientTestSuite) SetupSuite() {

suite.ctx, suite.clean = context.WithCancel(context.Background())
suite.client = setupCli(suite.ctx, re, suite.srv.GetEndpoints())
if suite.enableRouterClient {
suite.client.(interface{ EnableRouterClient() }).EnableRouterClient()
}

suite.regionHeartbeat, err = suite.grpcPDClient.RegionHeartbeat(suite.ctx)
re.NoError(err)
Expand Down

0 comments on commit 4b253eb

Please sign in to comment.