diff --git a/server/gc_service.go b/server/gc_service.go index 114482fdd39..c88a0395db6 100644 --- a/server/gc_service.go +++ b/server/gc_service.go @@ -31,18 +31,24 @@ import ( "github.com/tikv/pd/pkg/utils/tsoutil" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" - "google.golang.org/grpc" ) // GetGCSafePointV2 return gc safe point for the given keyspace. func (s *GrpcServer) GetGCSafePointV2(ctx context.Context, request *pdpb.GetGCSafePointV2Request) (*pdpb.GetGCSafePointV2Response, error) { - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetGCSafePointV2(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "GetGCSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetGCSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.GetGCSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetGCSafePointV2Response), nil + } } safePoint, err := s.safePointV2Manager.LoadGCSafePoint(request.GetKeyspaceId()) @@ -61,13 +67,20 @@ func (s *GrpcServer) GetGCSafePointV2(ctx context.Context, request *pdpb.GetGCSa // UpdateGCSafePointV2 update gc safe point for the given keyspace. func (s *GrpcServer) UpdateGCSafePointV2(ctx context.Context, request *pdpb.UpdateGCSafePointV2Request) (*pdpb.UpdateGCSafePointV2Response, error) { - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).UpdateGCSafePointV2(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateGCSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateGCSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.UpdateGCSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateGCSafePointV2Response), nil + } } newSafePoint := request.GetSafePoint() @@ -98,13 +111,20 @@ func (s *GrpcServer) UpdateGCSafePointV2(ctx context.Context, request *pdpb.Upda // UpdateServiceSafePointV2 update service safe point for the given keyspace. func (s *GrpcServer) UpdateServiceSafePointV2(ctx context.Context, request *pdpb.UpdateServiceSafePointV2Request) (*pdpb.UpdateServiceSafePointV2Response, error) { - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).UpdateServiceSafePointV2(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateServiceSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateServiceSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.UpdateServiceSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateServiceSafePointV2Response), nil + } } nowTSO, err := s.getGlobalTSO(ctx) @@ -195,13 +215,20 @@ func (s *GrpcServer) WatchGCSafePointV2(request *pdpb.WatchGCSafePointV2Request, // GetAllGCSafePointV2 return all gc safe point v2. func (s *GrpcServer) GetAllGCSafePointV2(ctx context.Context, request *pdpb.GetAllGCSafePointV2Request) (*pdpb.GetAllGCSafePointV2Response, error) { - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetAllGCSafePointV2(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "GetAllGCSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetAllGCSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.GetAllGCSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetAllGCSafePointV2Response), nil + } } startkey := keypath.GCSafePointV2Prefix() diff --git a/server/grpc_service.go b/server/grpc_service.go index d5fd8ae3e32..d73fa8334c9 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -20,7 +20,6 @@ import ( "fmt" "io" "path" - "runtime" "runtime/trace" "strconv" "strings" @@ -51,7 +50,6 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/multierr" "go.uber.org/zap" - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -233,36 +231,6 @@ func (s *schedulingClient) getPrimaryAddr() string { return s.primary } -type request interface { - GetHeader() *pdpb.RequestHeader -} - -type forwardFn func(ctx context.Context, client *grpc.ClientConn) (any, error) - -func (s *GrpcServer) unaryMiddleware(ctx context.Context, req request, fn forwardFn) (rsp any, err error) { - return s.unaryFollowerMiddleware(ctx, req, fn, nil) -} - -// unaryFollowerMiddleware adds the check of followers enable compared to unaryMiddleware. -func (s *GrpcServer) unaryFollowerMiddleware(ctx context.Context, req request, fn forwardFn, allowFollower *bool) (rsp any, err error) { - failpoint.Inject("customTimeout", func() { - time.Sleep(5 * time.Second) - }) - forwardedHost := grpcutil.GetForwardedHost(ctx) - if !s.isLocalRequest(forwardedHost) { - client, err := s.getDelegateClient(ctx, forwardedHost) - if err != nil { - return nil, err - } - ctx = grpcutil.ResetForwardContext(ctx) - return fn(ctx, client) - } - if err := s.validateRoleInRequest(ctx, req.GetHeader(), allowFollower); err != nil { - return nil, err - } - return nil, nil -} - // GetClusterInfo implements gRPC PDServer. func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest) (*pdpb.GetClusterInfoResponse, error) { // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID @@ -295,24 +263,20 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest func (s *GrpcServer) GetMinTS( ctx context.Context, request *pdpb.GetMinTSRequest, ) (*pdpb.GetMinTSResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetMinTS"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetMinTSResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetMinTS(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetMinTSResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetMinTSResponse), nil + } } var ( @@ -634,24 +598,20 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "Bootstrap"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.BootstrapResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).Bootstrap(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.BootstrapResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.BootstrapResponse), nil + } } rc := s.GetRaftCluster() @@ -678,24 +638,20 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque // IsBootstrapped implements gRPC PDServer. func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstrappedRequest) (*pdpb.IsBootstrappedResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "IsBootstrapped"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.IsBootstrappedResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).IsBootstrapped(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.IsBootstrappedResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.IsBootstrappedResponse), nil + } } rc := s.GetRaftCluster() @@ -707,24 +663,20 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr // AllocID implements gRPC PDServer. func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) (*pdpb.AllocIDResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "AllocID"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.AllocIDResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).AllocID(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.AllocIDResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.AllocIDResponse), nil + } } // We can use an allocator for all types ID allocation. @@ -742,18 +694,23 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) } // IsSnapshotRecovering implements gRPC PDServer. -func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { +func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, req *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { + if midResp, err := s.unaryMiddleware(ctx, req, "IsSnapshotRecovering"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.IsSnapshotRecoveringResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.IsSnapshotRecoveringResponse), nil + } } + // recovering mark is stored in etcd directly, there's no need to forward. marked, err := s.Server.IsSnapshotRecovering(ctx) if err != nil { @@ -769,25 +726,22 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetStore"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetStoreResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetStoreResponse), nil + } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetStore(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetStoreResponse), err - } + rc := s.GetRaftCluster() if rc == nil { return &pdpb.GetStoreResponse{Header: notBootstrappedHeader()}, nil @@ -825,24 +779,20 @@ func checkStore(rc *cluster.RaftCluster, storeID uint64) *pdpb.Error { // PutStore implements gRPC PDServer. func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*pdpb.PutStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "PutStore"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.PutStoreResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).PutStore(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.PutStoreResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.PutStoreResponse), nil + } } rc := s.GetRaftCluster() @@ -882,24 +832,20 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetAllStores"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetAllStoresResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetAllStores(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetAllStoresResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetAllStoresResponse), nil + } } rc := s.GetRaftCluster() @@ -927,24 +873,20 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "StoreHearbeat"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.StoreHeartbeatResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.StoreHeartbeatResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.StoreHeartbeatResponse), nil + } } if request.GetStats() == nil { @@ -1429,32 +1371,32 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error // GetRegion implements gRPC PDServer. func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetRegion"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetRegionResponse), nil + } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetRegion(ctx, request) - } - followerHandle := new(bool) - if rsp, err := s.unaryFollowerMiddleware(ctx, request, fn, followerHandle); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetRegionResponse), nil - } + failpoint.Inject("delayProcess", nil) var ( - rc *cluster.RaftCluster - region *core.RegionInfo + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + region *core.RegionInfo ) - if *followerHandle { + if rc == nil { + return &pdpb.GetRegionResponse{Header: notBootstrappedHeader()}, nil + } + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil @@ -1478,7 +1420,7 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque var buckets *metapb.Buckets // FIXME: If the bucket is disabled dynamically, the bucket information is returned unexpectedly - if !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + if !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { buckets = region.GetBuckets() } return &pdpb.GetRegionResponse{ @@ -1493,29 +1435,27 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetPrevRegion"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetPrevRegion(ctx, request) - } - followerHandle := new(bool) - if rsp, err := s.unaryFollowerMiddleware(ctx, request, fn, followerHandle); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetRegionResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetRegionResponse), nil + } } - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { // no need to check running status rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { @@ -1530,14 +1470,14 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR region := rc.GetPrevRegionByKey(request.GetRegionKey()) if region == nil { - if *followerHandle { + if followerHandle { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil } return &pdpb.GetRegionResponse{Header: wrapHeader()}, nil } var buckets *metapb.Buckets // FIXME: If the bucket is disabled dynamically, the bucket information is returned unexpectedly - if !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + if !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { buckets = region.GetBuckets() } return &pdpb.GetRegionResponse{ @@ -1552,29 +1492,27 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetRegionByID"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetRegionByID(ctx, request) - } - followerHandle := new(bool) - if rsp, err := s.unaryFollowerMiddleware(ctx, request, fn, followerHandle); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetRegionResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetRegionResponse), nil + } } - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil @@ -1587,18 +1525,18 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB } region := rc.GetRegion(request.GetRegionId()) failpoint.Inject("followerHandleError", func() { - if *followerHandle { + if followerHandle { region = nil } }) if region == nil { - if *followerHandle { + if followerHandle { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil } return &pdpb.GetRegionResponse{Header: wrapHeader()}, nil } var buckets *metapb.Buckets - if !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + if !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { buckets = region.GetBuckets() } return &pdpb.GetRegionResponse{ @@ -1614,29 +1552,27 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // Deprecated: use BatchScanRegions instead. // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ScanRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ScanRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).ScanRegions(ctx, request) - } - followerHandle := new(bool) - if rsp, err := s.unaryFollowerMiddleware(ctx, request, fn, followerHandle); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ScanRegionsResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ScanRegionsResponse), nil + } } - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.ScanRegionsResponse{Header: regionNotFound()}, nil @@ -1648,7 +1584,7 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR } } regions := rc.ScanRegions(request.GetStartKey(), request.GetEndKey(), int(request.GetLimit())) - if *followerHandle && len(regions) == 0 { + if followerHandle && len(regions) == 0 { return &pdpb.ScanRegionsResponse{Header: regionNotFound()}, nil } resp := &pdpb.ScanRegionsResponse{Header: wrapHeader()} @@ -1672,29 +1608,27 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR // BatchScanRegions implements gRPC PDServer. func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchScanRegionsRequest) (*pdpb.BatchScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "BatchScanRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.BatchScanRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).BatchScanRegions(ctx, request) - } - followerHandle := new(bool) - if rsp, err := s.unaryFollowerMiddleware(ctx, request, fn, followerHandle); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.BatchScanRegionsResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.BatchScanRegionsResponse), nil + } } - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.BatchScanRegionsResponse{Header: regionNotFound()}, nil @@ -1705,7 +1639,7 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc return &pdpb.BatchScanRegionsResponse{Header: notBootstrappedHeader()}, nil } } - needBucket := request.GetNeedBuckets() && !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() + needBucket := request.GetNeedBuckets() && !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() limit := request.GetLimit() // cast to core.KeyRanges and check the validation. keyRanges := core.NewKeyRangesWithSize(len(request.GetRanges())) @@ -1755,7 +1689,7 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc Buckets: buckets, }) } - if *followerHandle && len(regions) == 0 { + if followerHandle && len(regions) == 0 { return &pdpb.BatchScanRegionsResponse{Header: regionNotFound()}, nil } resp := &pdpb.BatchScanRegionsResponse{Header: wrapHeader(), Regions: regions} @@ -1764,24 +1698,20 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc // AskSplit implements gRPC PDServer. func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "AskSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.AskSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).AskSplit(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.AskSplitResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.AskSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1810,16 +1740,20 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "AskBatchSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.AskBatchSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.AskBatchSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1853,14 +1787,6 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp return convertAskSplitResponse(resp), nil } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).AskBatchSplit(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.AskBatchSplitResponse), err - } if !versioninfo.IsFeatureSupported(rc.GetOpts().GetClusterVersion(), versioninfo.BatchSplit) { return &pdpb.AskBatchSplitResponse{Header: s.incompatibleVersion("batch_split")}, nil @@ -1886,24 +1812,20 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp // ReportSplit implements gRPC PDServer. func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ReportSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ReportSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).ReportSplit(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ReportSplitResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ReportSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1924,24 +1846,20 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR // ReportBatchSplit implements gRPC PDServer. func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ReportBatchSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ReportBatchSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ReportBatchSplitResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ReportBatchSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1963,24 +1881,20 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB // GetClusterConfig implements gRPC PDServer. func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetClusterConfig"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetClusterConfigResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetClusterConfig(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetClusterConfigResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetClusterConfigResponse), nil + } } rc := s.GetRaftCluster() @@ -1995,24 +1909,20 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus // PutClusterConfig implements gRPC PDServer. func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "PutClusterConfig"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.PutClusterConfigResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).PutClusterConfig(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.PutClusterConfigResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.PutClusterConfigResponse), nil + } } rc := s.GetRaftCluster() @@ -2036,16 +1946,20 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ScatterRegion"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ScatterRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ScatterRegionResponse), nil + } } rc := s.GetRaftCluster() @@ -2096,15 +2010,6 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).ScatterRegion(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ScatterRegionResponse), err - } - if len(request.GetRegionsId()) > 0 { percentage, err := scatterRegions(rc, request.GetRegionsId(), request.GetGroup(), int(request.GetRetryLimit()), request.GetSkipStoreLimit()) if err != nil { @@ -2150,26 +2055,21 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg // GetGCSafePoint implements gRPC PDServer. func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetGCSafePoint"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetGCSafePointResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetGCSafePointResponse), nil + } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetGCSafePointResponse), err - } - rc := s.GetRaftCluster() if rc == nil { return &pdpb.GetGCSafePointResponse{Header: notBootstrappedHeader()}, nil @@ -2209,26 +2109,21 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { // UpdateGCSafePoint implements gRPC PDServer. func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateGCSafePoint"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.UpdateGCSafePointResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateGCSafePointResponse), nil + } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateGCSafePointResponse), err - } - rc := s.GetRaftCluster() if rc == nil { return &pdpb.UpdateGCSafePointResponse{Header: notBootstrappedHeader()}, nil @@ -2258,24 +2153,20 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateServiceGCSafePoint"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.UpdateServiceGCSafePointResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateServiceGCSafePointResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateServiceGCSafePointResponse), nil + } } rc := s.GetRaftCluster() @@ -2314,16 +2205,20 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetOperator"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetOperatorResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetOperatorResponse), nil + } } rc := s.GetRaftCluster() @@ -2357,14 +2252,6 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR return convertOperatorResponse(resp), nil } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetOperator(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetOperatorResponse), err - } opController := rc.GetOperatorController() requestID := request.GetRegionId() @@ -2388,24 +2275,24 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR // validateRequest checks if Server is leader and clusterID is matched. func (s *GrpcServer) validateRequest(header *pdpb.RequestHeader) error { - return s.validateRoleInRequest(context.TODO(), header, nil) + return s.validateRoleInRequest(context.TODO(), header, "") } // validateRoleInRequest checks if Server is leader when disallow follower-handle and clusterID is matched. // TODO: Call it in gRPC interceptor. -func (s *GrpcServer) validateRoleInRequest(ctx context.Context, header *pdpb.RequestHeader, allowFollower *bool) error { +func (s *GrpcServer) validateRoleInRequest(ctx context.Context, header *pdpb.RequestHeader, methodName string) error { if s.IsClosed() { return ErrNotStarted } + // Check follower handle if !s.member.IsLeader() { - if allowFollower == nil { + if _, ok := allowFollowerMethods[methodName]; !ok { return ErrNotLeader } if !grpcutil.IsFollowerHandleEnabled(ctx) { // TODO: change the error code return ErrFollowerHandlingNotAllowed } - *allowFollower = true } if clusterID := keypath.ClusterID(); header.GetClusterId() != clusterID { return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", clusterID, header.GetClusterId()) @@ -2634,16 +2521,20 @@ func (s *GrpcServer) SyncMaxTS(_ context.Context, request *pdpb.SyncMaxTSRequest // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "SplitRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.SplitRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.SplitRegionsResponse), nil + } } rc := s.GetRaftCluster() @@ -2679,15 +2570,6 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion } } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).SplitRegions(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.SplitRegionsResponse), err - } - finishedPercentage, newRegionIDs := rc.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit())) return &pdpb.SplitRegionsResponse{ Header: wrapHeader(), @@ -2700,24 +2582,20 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion // Only regions which split successfully will be scattered. // scatterFinishedPercentage indicates the percentage of successfully split regions that are scattered. func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "SplitAndScatterRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.SplitAndScatterRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.SplitAndScatterRegionsResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.SplitAndScatterRegionsResponse), nil + } } rc := s.GetRaftCluster() if rc == nil { @@ -3028,24 +2906,20 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { // ReportMinResolvedTS implements gRPC PDServer. func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ReportMinResolvedTS"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ReportMinResolvedTsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ReportMinResolvedTsResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ReportMinResolvedTsResponse), nil + } } rc := s.GetRaftCluster() @@ -3068,24 +2942,20 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo // SetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "SetExternalTimestamp"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.SetExternalTimestampResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.SetExternalTimestampResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.SetExternalTimestampResponse), nil + } } nowTSO, err := s.getGlobalTSO(ctx) @@ -3106,24 +2976,20 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set // GetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetExternalTimestamp"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetExternalTimestampResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { - return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request) - } - if rsp, err := s.unaryMiddleware(ctx, request, fn); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetExternalTimestampResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetExternalTimestampResponse), nil + } } timestamp := s.GetExternalTS() @@ -3132,9 +2998,3 @@ func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.Get Timestamp: timestamp, }, nil } - -func currentFunction() string { - counter, _, _, _ := runtime.Caller(1) - s := strings.Split(runtime.FuncForPC(counter).Name(), ".") - return s[len(s)-1] -} diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 00000000000..667c9da6bea --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,205 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + "runtime" + "strings" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/tikv/pd/pkg/utils/grpcutil" + "google.golang.org/grpc" +) + +type request interface { + GetHeader() *pdpb.RequestHeader +} + +type forwardFn func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) + +var forwardFns = map[string]forwardFn{ + "GetMinTS": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetMinTS(ctx, request.(*pdpb.GetMinTSRequest)) + }, + "Bootstrap": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).Bootstrap(ctx, request.(*pdpb.BootstrapRequest)) + }, + "IsBootstrapped": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).IsBootstrapped(ctx, request.(*pdpb.IsBootstrappedRequest)) + }, + "AllocID": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).AllocID(ctx, request.(*pdpb.AllocIDRequest)) + }, + "GetStore": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetStore(ctx, request.(*pdpb.GetStoreRequest)) + }, + "PutStore": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).PutStore(ctx, request.(*pdpb.PutStoreRequest)) + }, + "GetAllStores": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetAllStores(ctx, request.(*pdpb.GetAllStoresRequest)) + }, + "StoreHeartbeat": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).StoreHeartbeat(ctx, request.(*pdpb.StoreHeartbeatRequest)) + }, + "AskSplit": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).AskSplit(ctx, request.(*pdpb.AskSplitRequest)) + }, + "AskBatchSplit": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).AskBatchSplit(ctx, request.(*pdpb.AskBatchSplitRequest)) + }, + "ReportSplit": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).ReportSplit(ctx, request.(*pdpb.ReportSplitRequest)) + }, + "ReportBatchSplit": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).ReportBatchSplit(ctx, request.(*pdpb.ReportBatchSplitRequest)) + }, + "GetClusterConfig": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetClusterConfig(ctx, request.(*pdpb.GetClusterConfigRequest)) + }, + "PutClusterConfig": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).PutClusterConfig(ctx, request.(*pdpb.PutClusterConfigRequest)) + }, + "ScatterRegion": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).ScatterRegion(ctx, request.(*pdpb.ScatterRegionRequest)) + }, + "GetGCSafePoint": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetGCSafePoint(ctx, request.(*pdpb.GetGCSafePointRequest)) + }, + "UpdateGCSafePoint": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).UpdateGCSafePoint(ctx, request.(*pdpb.UpdateGCSafePointRequest)) + }, + "UpdateServiceGCSafePoint": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).UpdateServiceGCSafePoint(ctx, request.(*pdpb.UpdateServiceGCSafePointRequest)) + }, + "GetOperator": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetOperator(ctx, request.(*pdpb.GetOperatorRequest)) + }, + "SplitRegions": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).SplitRegions(ctx, request.(*pdpb.SplitRegionsRequest)) + }, + "SplitAndScatterRegions": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).SplitAndScatterRegions(ctx, request.(*pdpb.SplitAndScatterRegionsRequest)) + }, + "ReportMinResolvedTS": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).ReportMinResolvedTS(ctx, request.(*pdpb.ReportMinResolvedTsRequest)) + }, + "SetExternalTimestamp": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).SetExternalTimestamp(ctx, request.(*pdpb.SetExternalTimestampRequest)) + }, + "GetExternalTimestamp": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetExternalTimestamp(ctx, request.(*pdpb.GetExternalTimestampRequest)) + }, + + "GetGCSafePointV2": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetGCSafePointV2(ctx, request.(*pdpb.GetGCSafePointV2Request)) + }, + "UpdateGCSafePointV2": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).UpdateGCSafePointV2(ctx, request.(*pdpb.UpdateGCSafePointV2Request)) + }, + "UpdateServiceSafePointV2": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).UpdateServiceSafePointV2(ctx, request.(*pdpb.UpdateServiceSafePointV2Request)) + }, + "GetAllGCSafePointV2": func(ctx context.Context, client *grpc.ClientConn, request any) (any, error) { + return pdpb.NewPDClient(client).GetAllGCSafePointV2(ctx, request.(*pdpb.GetAllGCSafePointV2Request)) + }, + + "GetRegion": func(ctx context.Context, client *grpc.ClientConn, req any) (any, error) { + return pdpb.NewPDClient(client).GetRegion(ctx, req.(*pdpb.GetRegionRequest)) + }, + "GetPreRegion": func(ctx context.Context, client *grpc.ClientConn, req any) (any, error) { + return pdpb.NewPDClient(client).GetPrevRegion(ctx, req.(*pdpb.GetRegionRequest)) + }, + "GetRegionByID": func(ctx context.Context, client *grpc.ClientConn, req any) (any, error) { + return pdpb.NewPDClient(client).GetRegionByID(ctx, req.(*pdpb.GetRegionByIDRequest)) + }, + "ScanRegions": func(ctx context.Context, client *grpc.ClientConn, req any) (any, error) { + return pdpb.NewPDClient(client).ScanRegions(ctx, req.(*pdpb.ScanRegionsRequest)) + }, + "BatchScanRegions": func(ctx context.Context, client *grpc.ClientConn, req any) (any, error) { + return pdpb.NewPDClient(client).BatchScanRegions(ctx, req.(*pdpb.BatchScanRegionsRequest)) + }, +} + +var allowFollowerMethods = map[string]struct{}{ + "GetRegion": {}, + "GetPrevRegion": {}, + "GetRegionByID": {}, + "ScanRegions": {}, + "BatchScanRegions": {}, +} + +var notRateLimitMethods = map[string]struct{}{ + "GetGCSafePointV2": {}, + "UpdateGCSafePointV2": {}, + "UpdateServiceSafePointV2": {}, + "GetAllGCSafePointV2": {}, +} + +type middlewareResponse struct { + resp any + header *pdpb.ResponseHeader + deferFunc func() +} + +func (s *GrpcServer) unaryMiddleware(ctx context.Context, req request, methodName string) (rsp *middlewareResponse, err error) { + midResp := &middlewareResponse{} + _, ok := notRateLimitMethods[methodName] + if !ok && s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + limiter := s.GetGRPCRateLimiter() + if done, err := limiter.Allow(methodName); err != nil { + midResp.header = wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()) + return midResp, nil + } else { + midResp.deferFunc = done + } + } + resp, err := s.unaryFollowerMiddleware(ctx, req, forwardFns[methodName]) + if resp != nil || err != nil { + midResp.resp = resp + return midResp, err + } + if err := s.validateRoleInRequest(ctx, req.GetHeader(), methodName); err != nil { + return nil, err + } + return midResp, nil +} + +// unaryFollowerMiddleware forward the request to the leader if the request is +// not sent by the leader. (client <-> follower <-> leader) +func (s *GrpcServer) unaryFollowerMiddleware(ctx context.Context, req request, fn forwardFn) (rsp any, err error) { + failpoint.Inject("customTimeout", func() { + time.Sleep(5 * time.Second) + }) + forwardedHost := grpcutil.GetForwardedHost(ctx) + if s.isLocalRequest(forwardedHost) { + return nil, nil + } + client, err := s.getDelegateClient(ctx, forwardedHost) + if err != nil { + return nil, err + } + ctx = grpcutil.ResetForwardContext(ctx) + return fn(ctx, client, req) +} + +func currentFunction() string { + counter, _, _, _ := runtime.Caller(1) + s := strings.Split(runtime.FuncForPC(counter).Name(), ".") + return s[len(s)-1] +}