diff --git a/server/follower_controller.go b/server/follower_controller.go
index 886e8b33..e0501e49 100644
--- a/server/follower_controller.go
+++ b/server/follower_controller.go
@@ -585,13 +585,6 @@ type MessageWithTerm interface {
 	GetTerm() int64
 }
 
-func checkStatus(expected, actual proto.ServingStatus) error {
-	if actual != expected {
-		return status.Errorf(common.CodeInvalidStatus, "Received message in the wrong state. In %+v, should be %+v.", actual, expected)
-	}
-	return nil
-}
-
 func (fc *followerController) SendSnapshot(stream proto.OxiaLogReplication_SendSnapshotServer) error {
 	fc.Lock()
 
diff --git a/server/kv/db.go b/server/kv/db.go
index c32f10ec..6f9cf1e7 100644
--- a/server/kv/db.go
+++ b/server/kv/db.go
@@ -298,8 +298,9 @@ func (it *rangeScanIterator) Value() (*proto.GetResponse, error) {
 	}
 
 	res := &proto.GetResponse{
-		Key:   pb.String(it.Key()),
-		Value: se.Value,
+		Key:    pb.String(it.Key()),
+		Value:  se.Value,
+		Status: proto.Status_OK,
 		Version: &proto.Version{
 			VersionId:          se.VersionId,
 			ModificationsCount: se.ModificationsCount,
diff --git a/server/leader_controller.go b/server/leader_controller.go
index c43dc7a4..6f507da8 100644
--- a/server/leader_controller.go
+++ b/server/leader_controller.go
@@ -22,6 +22,8 @@ import (
 	"sync"
 	"time"
 
+	"google.golang.org/grpc/status"
+
 	"github.com/pkg/errors"
 	"go.uber.org/multierr"
 	pb "google.golang.org/protobuf/proto"
@@ -45,6 +47,7 @@ type LeaderController interface {
 	Read(ctx context.Context, request *proto.ReadRequest) <-chan GetResult
 	List(ctx context.Context, request *proto.ListRequest) (<-chan string, error)
 	ListSliceNoMutex(ctx context.Context, request *proto.ListRequest) ([]string, error)
+	RangeScan(ctx context.Context, request *proto.RangeScanRequest) (<-chan *proto.GetResponse, <-chan error, error)
 
 	// NewTerm Handle new term requests
 	NewTerm(req *proto.NewTermRequest) (*proto.NewTermResponse, error)
@@ -568,7 +571,7 @@ func (lc *leaderController) Read(ctx context.Context, request *proto.ReadRequest
 	ch := make(chan GetResult)
 
 	lc.RLock()
-	err := checkStatus(proto.ServingStatus_LEADER, lc.status)
+	err := checkStatusIsLeader(lc.status)
 	lc.RUnlock()
 	if err != nil {
 		go func() {
@@ -613,7 +616,7 @@ func (lc *leaderController) List(ctx context.Context, request *proto.ListRequest
 	ch := make(chan string)
 
 	lc.RLock()
-	err := checkStatus(proto.ServingStatus_LEADER, lc.status)
+	err := checkStatusIsLeader(lc.status)
 	lc.RUnlock()
 	if err != nil {
 		return nil, err
@@ -681,6 +684,74 @@ func (lc *leaderController) ListSliceNoMutex(ctx context.Context, request *proto
 	}
 }
 
+func (lc *leaderController) RangeScan(ctx context.Context, request *proto.RangeScanRequest) (<-chan *proto.GetResponse, <-chan error, error) {
+	ch := make(chan *proto.GetResponse)
+	errCh := make(chan error)
+
+	lc.RLock()
+	err := checkStatusIsLeader(lc.status)
+	lc.RUnlock()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	go lc.rangeScan(ctx, request, ch, errCh)
+
+	return ch, errCh, nil
+}
+
+func (lc *leaderController) rangeScan(ctx context.Context, request *proto.RangeScanRequest, ch chan<- *proto.GetResponse, errCh chan<- error) {
+	common.DoWithLabels(
+		ctx,
+		map[string]string{
+			"oxia":  "range-scan",
+			"shard": fmt.Sprintf("%d", lc.shardId),
+			"peer":  common.GetPeer(ctx),
+		},
+		func() {
+			lc.log.Debug("Received range-scan request", slog.Any("request", request))
+
+			it, err := lc.db.RangeScan(request)
+			if err != nil {
+				lc.log.Warn(
+					"Failed to process range-scan request",
+					slog.Any("error", err),
+				)
+				errCh <- err
+				close(ch)
+				close(errCh)
+				return
+			}
+
+			defer func() {
+				_ = it.Close()
+				// Note: close the channels only after the iterator has been
+				// closed. This runs in a separate goroutine, so if the channels
+				// were closed first, the caller could move on to its next step
+				// (for example db.Close) while the iterator was still open.
+				close(ch)
+				close(errCh)
+			}()
+
+			for ; it.Valid(); it.Next() {
+				gr, err := it.Value()
+				if err != nil {
+					errCh <- err
+					return
+				}
+
+				select {
+				case ch <- gr:
+				case <-ctx.Done():
+					// Stop if the caller has gone away, rather than
+					// blocking forever on a channel nobody reads.
+					return
+				}
+			}
+		},
+	)
+}
+
 // Write
 // A client sends a batch of entries to the leader
 //
@@ -714,7 +785,7 @@ func (lc *leaderController) write(ctx context.Context, request func(int64) *prot
 func (lc *leaderController) appendToWal(ctx context.Context, request func(int64) *proto.WriteRequest) (actualRequest *proto.WriteRequest, offset int64, timestamp uint64, err error) {
 	lc.Lock()
 
-	if err := checkStatus(proto.ServingStatus_LEADER, lc.status); err != nil {
+	if err := checkStatusIsLeader(lc.status); err != nil {
 		lc.Unlock()
 		return nil, wal.InvalidOffset, 0, err
 	}
@@ -1006,3 +1077,10 @@ func (lc *leaderController) KeepAlive(sessionId int64) error {
 func (lc *leaderController) CloseSession(request *proto.CloseSessionRequest) (*proto.CloseSessionResponse, error) {
 	return lc.sessionManager.CloseSession(request)
 }
+
+func checkStatusIsLeader(actual proto.ServingStatus) error {
+	if actual != proto.ServingStatus_LEADER {
+		return status.Errorf(common.CodeInvalidStatus, "Received message in the wrong state. In %+v, should be %+v.", actual, proto.ServingStatus_LEADER)
+	}
+	return nil
+}
diff --git a/server/leader_controller_test.go b/server/leader_controller_test.go
index 74e6da04..a53bcc0a 100644
--- a/server/leader_controller_test.go
+++ b/server/leader_controller_test.go
@@ -974,6 +974,61 @@ func TestLeaderController_List(t *testing.T) {
 	assert.Len(t, list, 0)
 }
 
+func TestLeaderController_RangeScan(t *testing.T) {
+	var shard int64 = 1
+
+	kvFactory, _ := kv.NewPebbleKVFactory(testKVOptions)
+	walFactory := newTestWalFactory(t)
+
+	lc, _ := NewLeaderController(Config{}, common.DefaultNamespace, shard, newMockRpcClient(), walFactory, kvFactory)
+	_, _ = lc.NewTerm(&proto.NewTermRequest{ShardId: shard, Term: 1})
+	_, _ = lc.BecomeLeader(context.Background(), &proto.BecomeLeaderRequest{
+		ShardId:           shard,
+		Term:              1,
+		ReplicationFactor: 1,
+		FollowerMaps:      nil,
+	})
+
+	_, err := lc.Write(context.Background(), &proto.WriteRequest{
+		ShardId: &shard,
+		Puts: []*proto.PutRequest{
+			{Key: "/a", Value: []byte{0}},
+			{Key: "/b", Value: []byte{0}},
+			{Key: "/c", Value: []byte{0}},
+			{Key: "/d", Value: []byte{0}},
+		},
+	})
+	assert.NoError(t, err)
+
+	ch, _, err := lc.RangeScan(context.Background(), &proto.RangeScanRequest{
+		ShardId:        &shard,
+		StartInclusive: "/a",
+		EndExclusive:   "/c",
+	})
+	assert.NoError(t, err)
+
+	gr, more := <-ch
+	assert.Equal(t, "/a", *gr.Key)
+	assert.True(t, more)
+	gr, more = <-ch
+	assert.Equal(t, "/b", *gr.Key)
+	assert.True(t, more)
+	gr, more = <-ch
+	assert.Nil(t, gr)
+	assert.False(t, more)
+
+	ch, _, err = lc.RangeScan(context.Background(), &proto.RangeScanRequest{
+		ShardId:        &shard,
+		StartInclusive: "/y",
+		EndExclusive:   "/z",
+	})
+	assert.NoError(t, err)
+
+	gr, more = <-ch
+	assert.Nil(t, gr)
+	assert.False(t, more)
+}
+
 func TestLeaderController_DeleteShard(t *testing.T) {
 	var shard int64 = 1
 
diff --git a/server/public_rpc_server.go b/server/public_rpc_server.go
index 7169a7cd..7c12f095 100644
--- a/server/public_rpc_server.go
+++ b/server/public_rpc_server.go
@@ -30,8 +30,9 @@ import (
 )
 
 const (
-	maxTotalReadValueSize = 4 << (10 * 2) // 4Mi
-	maxTotalListKeySize   = 4 << (10 * 2) // 4Mi
+	maxTotalScanBatchCount = 1000
+	maxTotalReadValueSize  = 4 << (10 * 2) // 4Mi
+	maxTotalListKeySize    = 4 << (10 * 2) // 4Mi
 )
 
 type publicRpcServer struct {
@@ -203,6 +204,69 @@ func (s *publicRpcServer) List(request *proto.ListRequest, stream proto.OxiaClie
 	}
 }
 
+//nolint:revive
+func (s *publicRpcServer) RangeScan(request *proto.RangeScanRequest, stream proto.OxiaClient_RangeScanServer) error {
+	s.log.Debug(
+		"RangeScan request",
+		slog.String("peer", common.GetPeer(stream.Context())),
+		slog.Any("req", request),
+	)
+
+	lc, err := s.getLeader(*request.ShardId)
+	if err != nil {
+		return err
+	}
+
+	ch, errCh, err := lc.RangeScan(stream.Context(), request)
+	if err != nil {
+		s.log.Warn(
+			"Failed to perform range-scan operation",
+			slog.Any("error", err),
+		)
+		return err
+	}
+
+	response := &proto.RangeScanResponse{}
+	var totalSize int
+
+	for {
+		select {
+		case err, ok := <-errCh:
+			if ok {
+				return err
+			}
+			// errCh was closed without an error: disable this case
+			// (receiving on a nil channel blocks forever) and keep
+			// draining ch so that the final batch is still flushed.
+			errCh = nil
+
+		case gr, more := <-ch:
+			if !more {
+				if len(response.Records) > 0 {
+					if err := stream.Send(response); err != nil {
+						return err
+					}
+				}
+				return nil
+			}
+
+			size := len(gr.Value)
+			if len(response.Records) >= maxTotalScanBatchCount || totalSize+size > maxTotalReadValueSize {
+				if err := stream.Send(response); err != nil {
+					return err
+				}
+				response = &proto.RangeScanResponse{}
+				totalSize = 0
+			}
+			response.Records = append(response.Records, gr)
+			totalSize += size
+
+		case <-stream.Context().Done():
+			return stream.Context().Err()
+		}
+	}
+}
+
 func (s *publicRpcServer) GetNotifications(req *proto.NotificationsRequest, stream proto.OxiaClient_GetNotificationsServer) error {
 	s.log.Debug(
 		"Get notifications",