From 054a3d43e047b2baeb22ad6389c09f337d1fc094 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 25 Dec 2024 15:39:39 +0800 Subject: [PATCH 1/3] *: unify the gRPC errors (#8910) ref tikv/pd#8922 Signed-off-by: Ryan Leung --- errors.toml | 4 +- pkg/errs/errno.go | 67 ++++++++- pkg/mcs/metastorage/server/grpc_service.go | 10 +- .../resourcemanager/server/grpc_service.go | 10 +- pkg/mcs/scheduling/server/grpc_service.go | 10 +- pkg/mcs/tso/server/grpc_service.go | 21 +-- pkg/mcs/tso/server/server.go | 9 +- pkg/syncer/server.go | 4 +- server/api/admin.go | 2 +- server/api/config.go | 2 +- server/forward.go | 10 +- server/grpc_service.go | 138 ++++++++---------- tests/server/cluster/cluster_test.go | 3 +- 13 files changed, 148 insertions(+), 142 deletions(-) diff --git a/errors.toml b/errors.toml index 785de6662f4..2ab3b014f5a 100644 --- a/errors.toml +++ b/errors.toml @@ -661,9 +661,9 @@ error = ''' init file log error, %s ''' -["PD:mcs:ErrNotFoundSchedulingAddr"] +["PD:mcs:ErrNotFoundSchedulingPrimary"] error = ''' -cannot find scheduling address +cannot find scheduling primary ''' ["PD:mcs:ErrSchedulingServer"] diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index 30e24647a3f..ee24b4d0673 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -14,7 +14,12 @@ package errs -import "github.com/pingcap/errors" +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/pingcap/errors" +) const ( // NotLeaderErr indicates the non-leader member received the requests which should be received by leader. @@ -31,6 +36,62 @@ const ( NotServedErr = "is not served" ) +// gRPC errors +var ( + // Canceled indicates the operation was canceled (typically by the caller). + ErrStreamClosed = status.Error(codes.Canceled, "stream is closed") + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + ErrUnknown = func(err error) error { + return status.Error(codes.Unknown, err.Error()) + } + + // DeadlineExceeded means operation expired before completion. + // For operations that change the state of the system, this error may be + // returned even if the operation has completed successfully. For + // example, a successful response from a server could have been delayed + // long enough for the deadline to expire. + ErrForwardTSOTimeout = status.Error(codes.DeadlineExceeded, "forward tso request timeout") + ErrTSOProxyRecvFromClientTimeout = status.Error(codes.DeadlineExceeded, "tso proxy timeout when receiving from client; stream closed by server") + ErrSendHeartbeatTimeout = status.Error(codes.DeadlineExceeded, "send heartbeat timeout") + + // NotFound means some requested entity (e.g., file or directory) was + // not found. + ErrNotFoundTSOAddr = status.Error(codes.NotFound, "not found tso address") + ErrNotFoundSchedulingAddr = status.Error(codes.NotFound, "not found scheduling address") + ErrNotFoundService = status.Error(codes.NotFound, "not found service") + + // ResourceExhausted indicates some resource has been exhausted, perhaps + // a per-user quota, or perhaps the entire file system is out of space. + ErrMaxCountTSOProxyRoutinesExceeded = status.Error(codes.ResourceExhausted, "max count of concurrent tso proxy routines exceeded") + ErrGRPCRateLimitExceeded = func(err error) error { + return status.Error(codes.ResourceExhausted, err.Error()) + } + + // FailedPrecondition indicates operation was rejected because the + // system is not in a state required for the operation's execution. + // For example, directory to be deleted may be non-empty, an rmdir + // operation is applied to a non-directory, etc. + ErrMismatchClusterID = func(clusterID, requestClusterID uint64) error { + return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", clusterID, requestClusterID) + } + + // Unavailable indicates the service is currently unavailable. + // This is a most likely a transient condition and may be corrected + // by retrying with a backoff. Note that it is not always safe to retry + // non-idempotent operations. + // ErrNotLeader is returned when current server is not the leader and not possible to process request. + // TODO: work as proxy. + ErrNotLeader = status.Error(codes.Unavailable, "not leader") + ErrNotStarted = status.Error(codes.Unavailable, "server not started") + ErrEtcdNotStarted = status.Error(codes.Unavailable, "server is started, but etcd not started") + ErrFollowerHandlingNotAllowed = status.Error(codes.Unavailable, "not leader and follower handling not allowed") +) + // common error in multiple packages var ( ErrGetSourceStore = errors.Normalize("failed to get the source store", errors.RFCCodeText("PD:common:ErrGetSourceStore")) @@ -484,6 +545,6 @@ var ( // Micro service errors var ( - ErrNotFoundSchedulingAddr = errors.Normalize("cannot find scheduling address", errors.RFCCodeText("PD:mcs:ErrNotFoundSchedulingAddr")) - ErrSchedulingServer = errors.Normalize("scheduling server meets %v", errors.RFCCodeText("PD:mcs:ErrSchedulingServer")) + ErrNotFoundSchedulingPrimary = errors.Normalize("cannot find scheduling primary", errors.RFCCodeText("PD:mcs:ErrNotFoundSchedulingPrimary")) + ErrSchedulingServer = errors.Normalize("scheduling server meets %v", errors.RFCCodeText("PD:mcs:ErrSchedulingServer")) ) diff --git a/pkg/mcs/metastorage/server/grpc_service.go b/pkg/mcs/metastorage/server/grpc_service.go index 00f4efb56fd..af09bd3a987 100644 --- a/pkg/mcs/metastorage/server/grpc_service.go +++ b/pkg/mcs/metastorage/server/grpc_service.go @@ -22,23 +22,17 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/kvproto/pkg/meta_storagepb" "github.com/pingcap/log" bs "github.com/tikv/pd/pkg/basicserver" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/registry" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/keypath" ) -var ( - // errNotLeader is returned when current server is not the leader. - errNotLeader = status.Errorf(codes.Unavailable, "not leader") -) - var _ meta_storagepb.MetaStorageServer = (*Service)(nil) // SetUpRestHandler is a hook to sets up the REST service. @@ -81,7 +75,7 @@ func (*Service) RegisterRESTHandler(_ map[string]http.Handler) error { func (s *Service) checkServing() error { if s.manager == nil || s.manager.srv == nil || !s.manager.srv.IsServing() { - return errNotLeader + return errs.ErrNotLeader } return nil } diff --git a/pkg/mcs/resourcemanager/server/grpc_service.go b/pkg/mcs/resourcemanager/server/grpc_service.go index 6c0d7ce0120..4cc162f7145 100644 --- a/pkg/mcs/resourcemanager/server/grpc_service.go +++ b/pkg/mcs/resourcemanager/server/grpc_service.go @@ -22,8 +22,6 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -31,15 +29,11 @@ import ( "github.com/pingcap/log" bs "github.com/tikv/pd/pkg/basicserver" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/registry" "github.com/tikv/pd/pkg/utils/apiutil" ) -var ( - // errNotLeader is returned when current server is not the leader. - errNotLeader = status.Errorf(codes.Unavailable, "not leader") -) - var _ rmpb.ResourceManagerServer = (*Service)(nil) // SetUpRestHandler is a hook to sets up the REST service. @@ -89,7 +83,7 @@ func (s *Service) GetManager() *Manager { func (s *Service) checkServing() error { if s.manager == nil || s.manager.srv == nil || !s.manager.srv.IsServing() { - return errNotLeader + return errs.ErrNotLeader } return nil } diff --git a/pkg/mcs/scheduling/server/grpc_service.go b/pkg/mcs/scheduling/server/grpc_service.go index 440b2d47d4f..3d1183bf734 100644 --- a/pkg/mcs/scheduling/server/grpc_service.go +++ b/pkg/mcs/scheduling/server/grpc_service.go @@ -23,8 +23,6 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" @@ -41,12 +39,6 @@ import ( "github.com/tikv/pd/pkg/versioninfo" ) -// gRPC errors -var ( - ErrNotStarted = status.Errorf(codes.Unavailable, "server not started") - ErrClusterMismatched = status.Errorf(codes.Unavailable, "cluster mismatched") -) - // SetUpRestHandler is a hook to sets up the REST service. var SetUpRestHandler = func(*Service) (http.Handler, apiutil.APIServiceGroup) { return dummyRestService{}, apiutil.APIServiceGroup{} @@ -107,7 +99,7 @@ func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { return errors.WithStack(err) case <-timer.C: atomic.StoreInt32(&s.closed, 1) - return status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout") + return errs.ErrSendHeartbeatTimeout } } diff --git a/pkg/mcs/tso/server/grpc_service.go b/pkg/mcs/tso/server/grpc_service.go index 59abed67213..dc85a730651 100644 --- a/pkg/mcs/tso/server/grpc_service.go +++ b/pkg/mcs/tso/server/grpc_service.go @@ -22,25 +22,18 @@ import ( "time" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/tsopb" "github.com/pingcap/log" bs "github.com/tikv/pd/pkg/basicserver" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/mcs/registry" "github.com/tikv/pd/pkg/utils/apiutil" "github.com/tikv/pd/pkg/utils/keypath" ) -// gRPC errors -var ( - ErrNotStarted = status.Errorf(codes.Unavailable, "server not started") - ErrClusterMismatched = status.Errorf(codes.Unavailable, "cluster mismatched") -) - var _ tsopb.TSOServer = (*Service)(nil) // SetUpRestHandler is a hook to sets up the REST service. @@ -102,14 +95,12 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error { start := time.Now() // TSO uses leader lease to determine validity. No need to check leader here. if s.IsClosed() { - return status.Errorf(codes.Unknown, "server not started") + return errs.ErrNotStarted } header := request.GetHeader() clusterID := header.GetClusterId() if clusterID != keypath.ClusterID() { - return status.Errorf( - codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", - keypath.ClusterID(), clusterID) + return errs.ErrMismatchClusterID(keypath.ClusterID(), clusterID) } keyspaceID := header.GetKeyspaceId() keyspaceGroupID := header.GetKeyspaceGroupId() @@ -119,7 +110,7 @@ func (s *Service) Tso(stream tsopb.TSO_TsoServer) error { keyspaceID, keyspaceGroupID, count) if err != nil { - return status.Error(codes.Unknown, err.Error()) + return errs.ErrUnknown(err) } keyspaceGroupIDStr := strconv.FormatUint(uint64(keyspaceGroupID), 10) tsoHandleDuration.WithLabelValues(keyspaceGroupIDStr).Observe(time.Since(start).Seconds()) @@ -220,10 +211,10 @@ func (s *Service) GetMinTS( func (s *Service) validRequest(header *tsopb.RequestHeader) (tsopb.ErrorType, error) { if s.IsClosed() || s.keyspaceGroupManager == nil { - return tsopb.ErrorType_NOT_BOOTSTRAPPED, ErrNotStarted + return tsopb.ErrorType_NOT_BOOTSTRAPPED, errs.ErrNotStarted } if header == nil || header.GetClusterId() != keypath.ClusterID() { - return tsopb.ErrorType_CLUSTER_MISMATCHED, ErrClusterMismatched + return tsopb.ErrorType_CLUSTER_MISMATCHED, errs.ErrMismatchClusterID(keypath.ClusterID(), header.GetClusterId()) } return tsopb.ErrorType_OK, nil } diff --git a/pkg/mcs/tso/server/server.go b/pkg/mcs/tso/server/server.go index 34f51573baf..ebd0cca8344 100644 --- a/pkg/mcs/tso/server/server.go +++ b/pkg/mcs/tso/server/server.go @@ -30,8 +30,6 @@ import ( "github.com/spf13/cobra" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/diagnosticspb" @@ -279,7 +277,7 @@ func (s *Server) GetTSOAllocatorManager(keyspaceGroupID uint32) (*tso.AllocatorM // TODO: Check if the sender is from the global TSO allocator func (s *Server) ValidateInternalRequest(_ *tsopb.RequestHeader, _ bool) error { if s.IsClosed() { - return ErrNotStarted + return errs.ErrNotStarted } return nil } @@ -288,11 +286,10 @@ func (s *Server) ValidateInternalRequest(_ *tsopb.RequestHeader, _ bool) error { // TODO: Check if the keyspace replica is the primary func (s *Server) ValidateRequest(header *tsopb.RequestHeader) error { if s.IsClosed() { - return ErrNotStarted + return errs.ErrNotStarted } if header.GetClusterId() != keypath.ClusterID() { - return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", - keypath.ClusterID(), header.GetClusterId()) + return errs.ErrMismatchClusterID(keypath.ClusterID(), header.GetClusterId()) } return nil } diff --git a/pkg/syncer/server.go b/pkg/syncer/server.go index 89af3f79ccc..150ff738c15 100644 --- a/pkg/syncer/server.go +++ b/pkg/syncer/server.go @@ -23,8 +23,6 @@ import ( "github.com/docker/go-units" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -208,7 +206,7 @@ func (s *RegionSyncer) Sync(ctx context.Context, stream pdpb.PD_SyncRegionsServe } clusterID := request.GetHeader().GetClusterId() if clusterID != keypath.ClusterID() { - return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", keypath.ClusterID(), clusterID) + return errs.ErrMismatchClusterID(keypath.ClusterID(), clusterID) } log.Info("establish sync region stream", zap.String("requested-server", request.GetMember().GetName()), diff --git a/server/api/admin.go b/server/api/admin.go index 434508c98df..d2be53cf40e 100644 --- a/server/api/admin.go +++ b/server/api/admin.go @@ -231,7 +231,7 @@ func (h *adminHandler) recoverAllocID(w http.ResponseWriter, r *http.Request) { func (h *adminHandler) deleteRegionCacheInSchedulingServer(id ...uint64) error { addr, ok := h.svr.GetServicePrimaryAddr(h.svr.Context(), constant.SchedulingServiceName) if !ok { - return errs.ErrNotFoundSchedulingAddr.FastGenByArgs() + return errs.ErrNotFoundSchedulingPrimary.FastGenByArgs() } var idStr string if len(id) > 0 { diff --git a/server/api/config.go b/server/api/config.go index 3eda889507f..a27a1ed5e9b 100644 --- a/server/api/config.go +++ b/server/api/config.go @@ -566,7 +566,7 @@ func (h *confHandler) GetPDServerConfig(w http.ResponseWriter, _ *http.Request) func (h *confHandler) getSchedulingServerConfig() (*config.Config, error) { addr, ok := h.svr.GetServicePrimaryAddr(h.svr.Context(), constant.SchedulingServiceName) if !ok { - return nil, errs.ErrNotFoundSchedulingAddr.FastGenByArgs() + return nil, errs.ErrNotFoundSchedulingPrimary.FastGenByArgs() } url := fmt.Sprintf("%s/scheduling/api/v1/config", addr) req, err := http.NewRequest(http.MethodGet, url, http.NoBody) diff --git a/server/forward.go b/server/forward.go index b3d0d63c81b..9a604410fc0 100644 --- a/server/forward.go +++ b/server/forward.go @@ -22,8 +22,6 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -107,7 +105,7 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings()) if maxConcurrentTSOProxyStreamings >= 0 { if newCount := s.concurrentTSOProxyStreamings.Add(1); newCount > maxConcurrentTSOProxyStreamings { - return errors.WithStack(ErrMaxCountTSOProxyRoutinesExceeded) + return errors.WithStack(errs.ErrMaxCountTSOProxyRoutinesExceeded) } } @@ -132,7 +130,7 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error { } if request.GetCount() == 0 { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") - return status.Error(codes.Unknown, err.Error()) + return errs.ErrUnknown(err) } forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward) if tsoStreamErr != nil { @@ -155,7 +153,7 @@ func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStre ) { forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName) if !ok || len(forwardedHost) == 0 { - return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(ErrNotFoundTSOAddr), nil + return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(errs.ErrNotFoundTSOAddr), nil } if forwardStream == nil || lastForwardedHost != forwardedHost { if cancelForward != nil { @@ -458,7 +456,7 @@ func (s *GrpcServer) getGlobalTSO(ctx context.Context) (pdpb.Timestamp, error) { } forwardedHost, ok = s.GetServicePrimaryAddr(ctx, constant.TSOServiceName) if !ok || forwardedHost == "" { - return pdpb.Timestamp{}, ErrNotFoundTSOAddr + return pdpb.Timestamp{}, errs.ErrNotFoundTSOAddr } forwardStream, err = s.getTSOForwardStream(forwardedHost) if err != nil { diff --git a/server/grpc_service.go b/server/grpc_service.go index d10421e87d7..398325cd30a 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -32,8 +32,6 @@ import ( "go.uber.org/multierr" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -67,23 +65,6 @@ const ( gRPCServiceName = "pdpb.PD" ) -// gRPC errors -var ( - // ErrNotLeader is returned when current server is not the leader and not possible to process request. - // TODO: work as proxy. - ErrNotLeader = status.Errorf(codes.Unavailable, "not leader") - ErrNotStarted = status.Errorf(codes.Unavailable, "server not started") - ErrSendHeartbeatTimeout = status.Errorf(codes.DeadlineExceeded, "send heartbeat timeout") - ErrNotFoundTSOAddr = status.Errorf(codes.NotFound, "not found tso address") - ErrNotFoundSchedulingAddr = status.Errorf(codes.NotFound, "not found scheduling address") - ErrNotFoundService = status.Errorf(codes.NotFound, "not found service") - ErrForwardTSOTimeout = status.Errorf(codes.DeadlineExceeded, "forward tso request timeout") - ErrMaxCountTSOProxyRoutinesExceeded = status.Errorf(codes.ResourceExhausted, "max count of concurrent tso proxy routines exceeded") - ErrTSOProxyRecvFromClientTimeout = status.Errorf(codes.DeadlineExceeded, "tso proxy timeout when receiving from client; stream closed by server") - ErrEtcdNotStarted = status.Errorf(codes.Unavailable, "server is started, but etcd not started") - ErrFollowerHandlingNotAllowed = status.Errorf(codes.Unavailable, "not leader and follower handling not allowed") -) - var ( errRegionHeartbeatSend = forwardFailCounter.WithLabelValues("region_heartbeat", "send") errRegionHeartbeatClient = forwardFailCounter.WithLabelValues("region_heartbeat", "client") @@ -137,7 +118,7 @@ func (s *tsoServer) send(m *pdpb.TsoResponse) error { return errors.WithStack(err) case <-timer.C: atomic.StoreInt32(&s.closed, 1) - return ErrForwardTSOTimeout + return errs.ErrForwardTSOTimeout } } @@ -167,7 +148,7 @@ func (s *tsoServer) recv(timeout time.Duration) (*pdpb.TsoRequest, error) { return req.request, nil case <-timer.C: atomic.StoreInt32(&s.closed, 1) - return nil, ErrTSOProxyRecvFromClientTimeout + return nil, errs.ErrTSOProxyRecvFromClientTimeout } } @@ -198,7 +179,7 @@ func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { return errors.WithStack(err) case <-timer.C: atomic.StoreInt32(&s.closed, 1) - return ErrSendHeartbeatTimeout + return errs.ErrSendHeartbeatTimeout } } @@ -302,7 +283,7 @@ func (s *GrpcServer) GetMinTS( if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -456,7 +437,7 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } // Here we purposely do not check the cluster ID because the client does not know the correct cluster ID @@ -506,7 +487,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return status.Error(codes.ResourceExhausted, err.Error()) + return errs.ErrGRPCRateLimitExceeded(err) } } if s.IsServiceIndependent(constant.TSOServiceName) { @@ -576,7 +557,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { if s.IsServiceIndependent(constant.TSOServiceName) { if request.GetCount() == 0 { err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") - return status.Error(codes.Unknown, err.Error()) + return errs.ErrUnknown(err) } forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, nil, request, tsDeadlineCh, lastForwardedHost, cancelForward) if tsoStreamErr != nil { @@ -591,11 +572,10 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { start := time.Now() // TSO uses leader lease to determine validity. No need to check leader here. if s.IsClosed() { - return status.Errorf(codes.Unknown, "server not started") + return errs.ErrNotStarted } if clusterID := keypath.ClusterID(); request.GetHeader().GetClusterId() != clusterID { - return status.Errorf(codes.FailedPrecondition, - "mismatch cluster id, need %d but got %d", clusterID, request.GetHeader().GetClusterId()) + return errs.ErrMismatchClusterID(clusterID, request.GetHeader().GetClusterId()) } count := request.GetCount() ctx, task := trace.NewTask(ctx, "tso") @@ -603,7 +583,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { task.End() tsoHandleDuration.Observe(time.Since(start).Seconds()) if err != nil { - return status.Error(codes.Unknown, err.Error()) + return errs.ErrUnknown(err) } response := &pdpb.TsoResponse{ Header: wrapHeader(), @@ -624,7 +604,7 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -666,7 +646,7 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -693,7 +673,7 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -727,7 +707,7 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } // recovering mark is stored in etcd directly, there's no need to forward. @@ -751,7 +731,7 @@ func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -805,7 +785,7 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -860,7 +840,7 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -903,7 +883,7 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -992,7 +972,7 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear func (s *GrpcServer) updateSchedulingClient(ctx context.Context) (*schedulingClient, error) { forwardedHost, _ := s.GetServicePrimaryAddr(ctx, constant.SchedulingServiceName) if forwardedHost == "" { - return nil, ErrNotFoundSchedulingAddr + return nil, errs.ErrNotFoundSchedulingAddr } pre := s.schedulingClient.Load() @@ -1029,7 +1009,7 @@ type bucketHeartbeatServer struct { func (b *bucketHeartbeatServer) send(bucket *pdpb.ReportBucketsResponse) error { if atomic.LoadInt32(&b.closed) == 1 { - return status.Errorf(codes.Canceled, "stream is closed") + return errs.ErrStreamClosed } done := make(chan error, 1) go func() { @@ -1046,7 +1026,7 @@ func (b *bucketHeartbeatServer) send(bucket *pdpb.ReportBucketsResponse) error { return err case <-timer.C: atomic.StoreInt32(&b.closed, 1) - return ErrSendHeartbeatTimeout + return errs.ErrSendHeartbeatTimeout } } @@ -1082,13 +1062,13 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return status.Error(codes.ResourceExhausted, err.Error()) + return errs.ErrGRPCRateLimitExceeded(err) } } for { request, err := server.recv() failpoint.Inject("grpcClientClosed", func() { - err = status.Error(codes.Canceled, "grpc client closed") + err = errs.ErrStreamClosed request = nil }) if err == io.EOF { @@ -1198,7 +1178,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return status.Error(codes.ResourceExhausted, err.Error()) + return errs.ErrGRPCRateLimitExceeded(err) } } for { @@ -1398,7 +1378,7 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error // GetRegion implements gRPC PDServer. func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { failpoint.Inject("rateLimit", func() { - failpoint.Return(nil, status.Error(codes.ResourceExhausted, errs.ErrRateLimitExceeded.Error())) + failpoint.Return(nil, errs.ErrGRPCRateLimitExceeded(errs.ErrRateLimitExceeded)) }) if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() @@ -1406,7 +1386,7 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1468,7 +1448,7 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1525,7 +1505,7 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1585,7 +1565,7 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1641,7 +1621,7 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1731,7 +1711,7 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1775,7 +1755,7 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } @@ -1849,7 +1829,7 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1885,7 +1865,7 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1922,7 +1902,7 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1952,7 +1932,7 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -1991,7 +1971,7 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } @@ -2103,7 +2083,7 @@ func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafe if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -2134,7 +2114,7 @@ func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafe // SyncRegions syncs the regions. func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { if s.IsClosed() || s.cluster == nil { - return ErrNotStarted + return errs.ErrNotStarted } if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() @@ -2142,12 +2122,12 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return status.Error(codes.ResourceExhausted, err.Error()) + return errs.ErrGRPCRateLimitExceeded(err) } } ctx := s.cluster.Context() if ctx == nil { - return ErrNotStarted + return errs.ErrNotStarted } return s.cluster.GetRegionSyncer().Sync(ctx, stream) } @@ -2160,7 +2140,7 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -2207,7 +2187,7 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -2261,7 +2241,7 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } @@ -2334,20 +2314,20 @@ func (s *GrpcServer) validateRequest(header *pdpb.RequestHeader) error { // TODO: Call it in gRPC interceptor. func (s *GrpcServer) validateRoleInRequest(ctx context.Context, header *pdpb.RequestHeader, allowFollower *bool) error { if s.IsClosed() { - return ErrNotStarted + return errs.ErrNotStarted } if !s.member.IsLeader() { if allowFollower == nil { - return ErrNotLeader + return errs.ErrNotLeader } if !grpcutil.IsFollowerHandleEnabled(ctx) { // TODO: change the error code - return ErrFollowerHandlingNotAllowed + return errs.ErrFollowerHandlingNotAllowed } *allowFollower = true } if clusterID := keypath.ClusterID(); header.GetClusterId() != clusterID { - return status.Errorf(codes.FailedPrecondition, "mismatch cluster id, need %d but got %d", clusterID, header.GetClusterId()) + return errs.ErrMismatchClusterID(clusterID, header.GetClusterId()) } return nil } @@ -2473,7 +2453,7 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } @@ -2537,7 +2517,7 @@ func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.S if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -2600,7 +2580,7 @@ const globalConfigPath = "/global/config/" // it should be set to `Payload bytes` instead of `Value string` func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlobalConfigRequest) (*pdpb.StoreGlobalConfigResponse, error) { if s.client == nil { - return nil, ErrEtcdNotStarted + return nil, errs.ErrEtcdNotStarted } if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() @@ -2608,7 +2588,7 @@ func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlo if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } configPath := request.GetConfigPath() @@ -2646,7 +2626,7 @@ func (s *GrpcServer) StoreGlobalConfig(_ context.Context, request *pdpb.StoreGlo // - `ConfigPath` if `Names` is nil can get all values and revision of current path func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlobalConfigRequest) (*pdpb.LoadGlobalConfigResponse, error) { if s.client == nil { - return nil, ErrEtcdNotStarted + return nil, errs.ErrEtcdNotStarted } if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() @@ -2654,7 +2634,7 @@ func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlo if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } configPath := request.GetConfigPath() @@ -2694,7 +2674,7 @@ func (s *GrpcServer) LoadGlobalConfig(ctx context.Context, request *pdpb.LoadGlo // Watch on revision which greater than or equal to the required revision. func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, server pdpb.PD_WatchGlobalConfigServer) error { if s.client == nil { - return ErrEtcdNotStarted + return errs.ErrEtcdNotStarted } if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() @@ -2702,7 +2682,7 @@ func (s *GrpcServer) WatchGlobalConfig(req *pdpb.WatchGlobalConfigRequest, serve if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return status.Error(codes.ResourceExhausted, err.Error()) + return errs.ErrGRPCRateLimitExceeded(err) } } ctx, cancel := context.WithCancel(server.Context()) @@ -2799,7 +2779,7 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -2837,7 +2817,7 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { @@ -2873,7 +2853,7 @@ func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.Get if done, err := limiter.Allow(fName); err == nil { defer done() } else { - return nil, status.Error(codes.ResourceExhausted, err.Error()) + return nil, errs.ErrGRPCRateLimitExceeded(err) } } fn := func(ctx context.Context, client *grpc.ClientConn) (any, error) { diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index b7467eb99a5..357a76ace21 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -39,6 +39,7 @@ import ( "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/core/storelimit" "github.com/tikv/pd/pkg/dashboard" + "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/id" "github.com/tikv/pd/pkg/mock/mockid" "github.com/tikv/pd/pkg/mock/mockserver" @@ -767,7 +768,7 @@ func TestNotLeader(t *testing.T) { grpcStatus, ok := status.FromError(err) re.True(ok) re.Equal(codes.Unavailable, grpcStatus.Code()) - re.ErrorContains(server.ErrNotLeader, grpcStatus.Message()) + re.ErrorContains(errs.ErrNotLeader, grpcStatus.Message()) } func TestStoreVersionChange(t *testing.T) { From 95bfbe69ae1b48bac2b76b0f229cc030ac89a0ae Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 25 Dec 2024 16:03:03 +0800 Subject: [PATCH 2/3] client/tso: init the ticker when TSO Follower Proxy is already enabled (#8948) close tikv/pd#8947 Init the ticker directly when TSO Follower Proxy is already enabled. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/clients/tso/dispatcher.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/client/clients/tso/dispatcher.go b/client/clients/tso/dispatcher.go index bdac8096f85..58722088886 100644 --- a/client/clients/tso/dispatcher.go +++ b/client/clients/tso/dispatcher.go @@ -477,14 +477,22 @@ func (td *tsoDispatcher) connectionCtxsUpdater() { ) log.Info("[tso] start tso connection contexts updater") - setNewUpdateTicker := func(ticker *time.Ticker) { + setNewUpdateTicker := func(interval time.Duration) { if updateTicker.C != nil { updateTicker.Stop() } - updateTicker = ticker + if interval == 0 { + updateTicker = &time.Ticker{} + } else { + updateTicker = time.NewTicker(interval) + } + } + // If the TSO Follower Proxy is enabled, set the update interval to the member update interval. + if option.GetEnableTSOFollowerProxy() { + setNewUpdateTicker(sd.MemberUpdateInterval) } // Set to nil before returning to ensure that the existing ticker can be GC. - defer setNewUpdateTicker(nil) + defer setNewUpdateTicker(0) for { provider.updateConnectionCtxs(ctx, connectionCtxs) @@ -499,13 +507,11 @@ func (td *tsoDispatcher) connectionCtxsUpdater() { if enableTSOFollowerProxy && updateTicker.C == nil { // Because the TSO Follower Proxy is enabled, // the periodic check needs to be performed. - setNewUpdateTicker(time.NewTicker(sd.MemberUpdateInterval)) + setNewUpdateTicker(sd.MemberUpdateInterval) } else if !enableTSOFollowerProxy && updateTicker.C != nil { // Because the TSO Follower Proxy is disabled, // the periodic check needs to be turned off. - setNewUpdateTicker(&time.Ticker{}) - } else { - continue + setNewUpdateTicker(0) } case <-updateTicker.C: // Triggered periodically when the TSO Follower Proxy is enabled. From c2f72acc3335dfaa11fb4b8df0d5cce538db965a Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 25 Dec 2024 17:02:24 +0800 Subject: [PATCH 3/3] api: return not found when region doesn't exist (#8869) close tikv/pd#8868 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/scheduling/server/apis/v1/api.go | 4 ++++ server/api/region.go | 12 ++++++++++++ server/api/region_test.go | 8 +++++++- tests/integrations/mcs/scheduling/api_test.go | 3 +++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index 3d2d0005a24..535fa79ee0c 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -1476,6 +1476,10 @@ func getRegionByID(c *gin.Context) { c.String(http.StatusBadRequest, err.Error()) return } + if regionID == 0 { + c.String(http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs().Error()) + return + } regionInfo := svr.GetBasicCluster().GetRegion(regionID) if regionInfo == nil { c.String(http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) diff --git a/server/api/region.go b/server/api/region.go index afc32d2e762..a439cbfb349 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -67,8 +67,16 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } + if regionID == 0 { + h.rd.JSON(w, http.StatusBadRequest, errs.ErrRegionInvalidID.FastGenByArgs()) + return + } regionInfo := rc.GetRegion(regionID) + if regionInfo == nil { + h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error()) + return + } b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) @@ -101,6 +109,10 @@ func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { } regionInfo := rc.GetRegionByKey(paramsByte[0]) + if regionInfo == nil { + h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs().Error()) + return + } b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) diff --git a/server/api/region_test.go b/server/api/region_test.go index ae91b41ef5e..4e0929636e8 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -80,7 +80,11 @@ func (suite *regionTestSuite) TestRegion() { r.UpdateBuckets(buckets, r.GetBuckets()) re := suite.Require() mustRegionHeartbeat(re, suite.svr, r) - url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) + url := fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, 0) + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusBadRequest))) + url = fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, 2333) + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound))) + url = fmt.Sprintf("%s/region/id/%d", suite.urlPrefix, r.GetID()) r1 := &response.RegionInfo{} r1m := make(map[string]any) re.NoError(tu.ReadGetJSON(re, testDialClient, url, r1)) @@ -96,6 +100,8 @@ func (suite *regionTestSuite) TestRegion() { re.Equal(core.HexRegionKeyStr([]byte("a")), keys[0].(string)) re.Equal(core.HexRegionKeyStr([]byte("b")), keys[1].(string)) + url = fmt.Sprintf("%s/region/key/%s", suite.urlPrefix, "c") + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.Status(re, http.StatusNotFound))) url = fmt.Sprintf("%s/region/key/%s", suite.urlPrefix, "a") r2 := &response.RegionInfo{} re.NoError(tu.ReadGetJSON(re, testDialClient, url, r2)) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 14b867a587d..abace06bb78 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -714,6 +714,9 @@ func (suite *apiTestSuite) checkRegions(cluster *tests.TestCluster) { err = testutil.ReadGetJSON(re, tests.TestDialClient, urlPrefix, &resp) re.NoError(err) re.Equal(3., resp["count"]) + urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/0", scheServerAddr) + testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, + testutil.Status(re, http.StatusBadRequest)) urlPrefix = fmt.Sprintf("%s/scheduling/api/v1/regions/233", scheServerAddr) testutil.CheckGetJSON(tests.TestDialClient, urlPrefix, nil, testutil.Status(re, http.StatusNotFound), testutil.StringContain(re, "not found"))