From ffe9d3c9bb94423ede17ab40c1c008e04fd44d9b Mon Sep 17 00:00:00 2001 From: SimFG Date: Tue, 15 Oct 2024 19:28:57 +0800 Subject: [PATCH] support to replicate the database by the create request Signed-off-by: SimFG --- .github/workflows/ci.yaml | 138 ++++++++++++ core/api/replicate_manager.go | 8 +- core/api/replicate_manager_test.go | 4 +- core/mocks/channel_manager.go | 44 ++-- core/reader/collection_reader.go | 26 ++- core/reader/collection_reader_test.go | 11 +- core/reader/replicate_channel_manager.go | 23 +- core/reader/replicate_channel_manager_test.go | 10 +- core/util/msgpack.go | 14 +- server/cdc_impl.go | 20 +- server/cdc_impl_test.go | 70 +++--- server/model/common.go | 4 + server/model/meta/task.go | 1 + server/model/request/create.go | 1 + tests/testcases/test_cdc_database.py | 200 ++++++++++++++++++ 15 files changed, 486 insertions(+), 88 deletions(-) create mode 100644 tests/testcases/test_cdc_database.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 863153e8..c537ae4f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -581,3 +581,141 @@ jobs: scripts/k8s_logs tests/deployment/upstream/logs server/server.log + milvus-cdc-function-test-with-db: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - uses: actions/setup-go@v4 + with: + go-version: '1.21' + cache-dependency-path: server/go.sum + cache: true + + - name: set up python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + cache: 'pip' + + - name: Build CDC + timeout-minutes: 15 + working-directory: server + shell: bash + run: | + make build + ls -l + + - name: Creating kind cluster + uses: helm/kind-action@v1.2.0 + + - name: Print cluster information + run: | + kubectl config view + kubectl cluster-info + kubectl get nodes + kubectl get pods -o wide -n kube-system + helm version + kubectl version + + - name: Deploy Source Milvus + timeout-minutes: 15 + shell: bash + working-directory: tests/deployment/upstream/ + run: | + docker compose up -d + bash ../../../scripts/check_healthy.sh + docker compose ps -a + + - name: Deploy Downstream Milvus + timeout-minutes: 15 + shell: bash + working-directory: tests/deployment/downstream + run: | + helm repo add milvus https://zilliztech.github.io/milvus-helm + helm repo update + helm install --wait --timeout 720s cdc-downstream milvus/milvus -f standalone-values-auth.yaml + kubectl get pods + kubectl port-forward service/cdc-downstream-milvus 19500:19530 >/dev/null 2>&1 & + sleep 20s + nc -vz 127.0.0.1 19500 + + - name: Deploy Milvus CDC + timeout-minutes: 15 + working-directory: server + shell: bash + run: | + cp ../deployment/docker/cdc.yaml configs/cdc.yaml + ../bin/cdc > server.log 2>&1 & + sleep 20s + + - name: Create CDC task + timeout-minutes: 15 + run: | + curl --location '127.0.0.1:8444/cdc' \ + --header 'Content-Type: application/json' \ + --data '{ + "request_type": "create", + "request_data": { + "milvus_connect_param": { + "uri": "http://127.0.0.1:19500", + "token": "root:Milvus", + "connect_timeout": 120 + }, + "collection_infos": [ + { + "name": "*" + } + ], + "database_info": { + "name": "foo" + } + } + }' + + - name: Run test + timeout-minutes: 15 + shell: bash + working-directory: tests + run: | + pip install -r requirements.txt --trusted-host https://test.pypi.org + pytest testcases/test_cdc_database.py --upstream_host 127.0.0.1 --upstream_port 19530 --downstream_host 127.0.0.1 --downstream_port 19500 + + - name: List CDC task + if: ${{ always() }} + timeout-minutes: 15 + 
working-directory: server + shell: bash + run: | + cat server.log | tail -n 100 + curl --location '127.0.0.1:8444/cdc' \ + --header 'Content-Type: application/json' \ + --data '{ + "request_type": "list" + }' + + - name: Export upstream milvus logs + if: ${{ always() }} + timeout-minutes: 5 + working-directory: tests/deployment/upstream + run: | + docker compose ps -a + docker stats --no-stream + bash ../../../scripts/export_log_docker.sh + - name: Export downstream milvus logs + if: ${{ always() }} + timeout-minutes: 5 + working-directory: scripts + run: | + kubectl get pods || true + bash export_log_k8s.sh default cdc-downstream k8s_logs + + - name: Upload logs + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: func-test-logs-with-db + path: | + scripts/k8s_logs + tests/deployment/upstream/logs + server/server.log diff --git a/core/api/replicate_manager.go b/core/api/replicate_manager.go index 4b61e019..c66233f6 100644 --- a/core/api/replicate_manager.go +++ b/core/api/replicate_manager.go @@ -35,9 +35,9 @@ type ChannelManager interface { AddDroppedCollection(ids []int64) AddDroppedPartition(ids []int64) - StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error + StartReadCollection(ctx context.Context, db *model.DatabaseInfo, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error StopReadCollection(ctx context.Context, info *pb.CollectionInfo) error - AddPartition(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error + AddPartition(ctx context.Context, dbInfo *model.DatabaseInfo, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error GetChannelChan() <-chan string GetMsgChan(pChannel string) <-chan *ReplicateMsg @@ -103,7 +103,7 @@ func (d *DefaultChannelManager) AddDroppedPartition(ids []int64) { log.Warn("AddDroppedPartition is not implemented, please check it") } -func (d *DefaultChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { +func (d *DefaultChannelManager) StartReadCollection(ctx context.Context, db *model.DatabaseInfo, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { log.Warn("StartReadCollection is not implemented, please check it") return nil } @@ -113,7 +113,7 @@ func (d *DefaultChannelManager) StopReadCollection(ctx context.Context, info *pb return nil } -func (d *DefaultChannelManager) AddPartition(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { +func (d *DefaultChannelManager) AddPartition(ctx context.Context, dbInfo *model.DatabaseInfo, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { log.Warn("AddPartition is not implemented, please check it") return nil } diff --git a/core/api/replicate_manager_test.go b/core/api/replicate_manager_test.go index 0024017a..4ee69317 100644 --- a/core/api/replicate_manager_test.go +++ b/core/api/replicate_manager_test.go @@ -49,7 +49,7 @@ func TestDefaultChannelManager_AddPartition(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := &DefaultChannelManager{} - if err := d.AddPartition(tt.args.ctx, tt.args.collectionInfo, tt.args.partitionInfo); (err != nil) != tt.wantErr { + if err := d.AddPartition(tt.args.ctx, nil, tt.args.collectionInfo, tt.args.partitionInfo); (err != nil) != tt.wantErr { t.Errorf("AddPartition() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -141,7 +141,7 @@ func 
TestDefaultChannelManager_StartReadCollection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := &DefaultChannelManager{} - if err := d.StartReadCollection(tt.args.ctx, tt.args.info, tt.args.seekPositions); (err != nil) != tt.wantErr { + if err := d.StartReadCollection(tt.args.ctx, nil, tt.args.info, tt.args.seekPositions); (err != nil) != tt.wantErr { t.Errorf("StartReadCollection() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/core/mocks/channel_manager.go b/core/mocks/channel_manager.go index 8bef1f55..591ae0bc 100644 --- a/core/mocks/channel_manager.go +++ b/core/mocks/channel_manager.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" + model "github.com/zilliztech/milvus-cdc/core/model" + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" pb "github.com/zilliztech/milvus-cdc/core/pb" @@ -93,13 +95,13 @@ func (_c *ChannelManager_AddDroppedPartition_Call) RunAndReturn(run func([]int64 return _c } -// AddPartition provides a mock function with given fields: ctx, collectionInfo, partitionInfo -func (_m *ChannelManager) AddPartition(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { - ret := _m.Called(ctx, collectionInfo, partitionInfo) +// AddPartition provides a mock function with given fields: ctx, dbInfo, collectionInfo, partitionInfo +func (_m *ChannelManager) AddPartition(ctx context.Context, dbInfo *model.DatabaseInfo, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { + ret := _m.Called(ctx, dbInfo, collectionInfo, partitionInfo) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *pb.CollectionInfo, *pb.PartitionInfo) error); ok { - r0 = rf(ctx, collectionInfo, partitionInfo) + if rf, ok := ret.Get(0).(func(context.Context, *model.DatabaseInfo, *pb.CollectionInfo, *pb.PartitionInfo) error); ok { + r0 = rf(ctx, dbInfo, collectionInfo, partitionInfo) } else { r0 = ret.Error(0) } @@ -114,15 +116,16 @@ type ChannelManager_AddPartition_Call struct { // AddPartition is a helper method to define mock.On call // - ctx context.Context +// - dbInfo *model.DatabaseInfo // - collectionInfo *pb.CollectionInfo // - partitionInfo *pb.PartitionInfo -func (_e *ChannelManager_Expecter) AddPartition(ctx interface{}, collectionInfo interface{}, partitionInfo interface{}) *ChannelManager_AddPartition_Call { - return &ChannelManager_AddPartition_Call{Call: _e.mock.On("AddPartition", ctx, collectionInfo, partitionInfo)} +func (_e *ChannelManager_Expecter) AddPartition(ctx interface{}, dbInfo interface{}, collectionInfo interface{}, partitionInfo interface{}) *ChannelManager_AddPartition_Call { + return &ChannelManager_AddPartition_Call{Call: _e.mock.On("AddPartition", ctx, dbInfo, collectionInfo, partitionInfo)} } -func (_c *ChannelManager_AddPartition_Call) Run(run func(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo)) *ChannelManager_AddPartition_Call { +func (_c *ChannelManager_AddPartition_Call) Run(run func(ctx context.Context, dbInfo *model.DatabaseInfo, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo)) *ChannelManager_AddPartition_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*pb.CollectionInfo), args[2].(*pb.PartitionInfo)) + run(args[0].(context.Context), args[1].(*model.DatabaseInfo), args[2].(*pb.CollectionInfo), args[3].(*pb.PartitionInfo)) }) return _c } @@ -132,7 +135,7 @@ func (_c *ChannelManager_AddPartition_Call) Return(_a0 error) 
*ChannelManager_Ad return _c } -func (_c *ChannelManager_AddPartition_Call) RunAndReturn(run func(context.Context, *pb.CollectionInfo, *pb.PartitionInfo) error) *ChannelManager_AddPartition_Call { +func (_c *ChannelManager_AddPartition_Call) RunAndReturn(run func(context.Context, *model.DatabaseInfo, *pb.CollectionInfo, *pb.PartitionInfo) error) *ChannelManager_AddPartition_Call { _c.Call.Return(run) return _c } @@ -355,13 +358,13 @@ func (_c *ChannelManager_SetCtx_Call) RunAndReturn(run func(context.Context)) *C return _c } -// StartReadCollection provides a mock function with given fields: ctx, info, seekPositions -func (_m *ChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { - ret := _m.Called(ctx, info, seekPositions) +// StartReadCollection provides a mock function with given fields: ctx, db, info, seekPositions +func (_m *ChannelManager) StartReadCollection(ctx context.Context, db *model.DatabaseInfo, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { + ret := _m.Called(ctx, db, info, seekPositions) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *pb.CollectionInfo, []*msgpb.MsgPosition) error); ok { - r0 = rf(ctx, info, seekPositions) + if rf, ok := ret.Get(0).(func(context.Context, *model.DatabaseInfo, *pb.CollectionInfo, []*msgpb.MsgPosition) error); ok { + r0 = rf(ctx, db, info, seekPositions) } else { r0 = ret.Error(0) } @@ -376,15 +379,16 @@ type ChannelManager_StartReadCollection_Call struct { // StartReadCollection is a helper method to define mock.On call // - ctx context.Context +// - db *model.DatabaseInfo // - info *pb.CollectionInfo // - seekPositions []*msgpb.MsgPosition -func (_e *ChannelManager_Expecter) StartReadCollection(ctx interface{}, info interface{}, seekPositions interface{}) *ChannelManager_StartReadCollection_Call { - return &ChannelManager_StartReadCollection_Call{Call: _e.mock.On("StartReadCollection", ctx, info, seekPositions)} +func (_e *ChannelManager_Expecter) StartReadCollection(ctx interface{}, db interface{}, info interface{}, seekPositions interface{}) *ChannelManager_StartReadCollection_Call { + return &ChannelManager_StartReadCollection_Call{Call: _e.mock.On("StartReadCollection", ctx, db, info, seekPositions)} } -func (_c *ChannelManager_StartReadCollection_Call) Run(run func(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition)) *ChannelManager_StartReadCollection_Call { +func (_c *ChannelManager_StartReadCollection_Call) Run(run func(ctx context.Context, db *model.DatabaseInfo, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition)) *ChannelManager_StartReadCollection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*pb.CollectionInfo), args[2].([]*msgpb.MsgPosition)) + run(args[0].(context.Context), args[1].(*model.DatabaseInfo), args[2].(*pb.CollectionInfo), args[3].([]*msgpb.MsgPosition)) }) return _c } @@ -394,7 +398,7 @@ func (_c *ChannelManager_StartReadCollection_Call) Return(_a0 error) *ChannelMan return _c } -func (_c *ChannelManager_StartReadCollection_Call) RunAndReturn(run func(context.Context, *pb.CollectionInfo, []*msgpb.MsgPosition) error) *ChannelManager_StartReadCollection_Call { +func (_c *ChannelManager_StartReadCollection_Call) RunAndReturn(run func(context.Context, *model.DatabaseInfo, *pb.CollectionInfo, []*msgpb.MsgPosition) error) *ChannelManager_StartReadCollection_Call { _c.Call.Return(run) return _c } diff --git 
a/core/reader/collection_reader.go b/core/reader/collection_reader.go index 50b966b7..636c146a 100644 --- a/core/reader/collection_reader.go +++ b/core/reader/collection_reader.go @@ -35,6 +35,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/config" "github.com/zilliztech/milvus-cdc/core/log" + "github.com/zilliztech/milvus-cdc/core/model" "github.com/zilliztech/milvus-cdc/core/pb" "github.com/zilliztech/milvus-cdc/core/util" ) @@ -48,7 +49,7 @@ type CollectionInfo struct { positions map[string]*commonpb.KeyDataPair } -type ShouldReadFunc func(*pb.CollectionInfo) bool +type ShouldReadFunc func(*model.DatabaseInfo, *pb.CollectionInfo) bool var _ api.Reader = (*CollectionReader)(nil) @@ -112,7 +113,8 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { } collectionLog.Info("has watched to read collection") - if !reader.shouldReadFunc(info) { + dbInfo := reader.metaOp.GetDatabaseInfoForCollection(ctx, info.ID) + if !reader.shouldReadFunc(&dbInfo, info) { collectionLog.Info("the collection should not be read") return false } @@ -124,7 +126,7 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { Timestamp: info.CreateTime, }) } - if err := reader.channelManager.StartReadCollection(ctx, info, startPositions); err != nil { + if err := reader.channelManager.StartReadCollection(ctx, &dbInfo, info, startPositions); err != nil { collectionLog.Warn("fail to start to replicate the collection data in the watch process", zap.Any("info", info), zap.Error(err)) reader.sendError(err) } @@ -168,12 +170,13 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { Name: collectionName, }, } - if !reader.shouldReadFunc(tmpCollectionInfo) { + dbInfo := reader.metaOp.GetDatabaseInfoForCollection(ctx, tmpCollectionInfo.ID) + if !reader.shouldReadFunc(&dbInfo, tmpCollectionInfo) { partitionLog.Info("the partition should not be read", zap.String("name", collectionName)) return true } - err := reader.channelManager.AddPartition(ctx, tmpCollectionInfo, info) + err := reader.channelManager.AddPartition(ctx, &dbInfo, tmpCollectionInfo, info) if err != nil { partitionLog.Warn("fail to add partition", zap.String("collection_name", collectionName), zap.Any("partition", info), zap.Error(err)) reader.sendError(err) @@ -187,7 +190,8 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { readerLog := log.With(zap.String("task_id", reader.id)) existedCollectionInfos, err := reader.metaOp.GetAllCollection(ctx, func(info *pb.CollectionInfo) bool { - return !reader.shouldReadFunc(info) + // return !reader.shouldReadFunc(info) + return false }) if err != nil { readerLog.Warn("get all collection failed", zap.Error(err)) @@ -230,7 +234,8 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { readerLog.Info("skip to start to read collection", zap.String("name", info.Schema.Name), zap.Int64("collection_id", info.ID)) continue } - if !reader.shouldReadFunc(info) { + dbInfo := reader.metaOp.GetDatabaseInfoForCollection(ctx, info.ID) + if !reader.shouldReadFunc(&dbInfo, info) { readerLog.Info("the collection is not in the watch list", zap.String("name", info.Schema.Name), zap.Int64("collection_id", info.ID)) continue } @@ -252,7 +257,7 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { zap.String("name", info.Schema.Name), zap.Int64("collection_id", info.ID), zap.String("state", info.State.String())) - if err := reader.channelManager.StartReadCollection(ctx, info, seekPositions); err != nil { + if err := 
reader.channelManager.StartReadCollection(ctx, &dbInfo, info, seekPositions); err != nil { readerLog.Warn("fail to start to replicate the collection data", zap.Any("collection", info), zap.Error(err)) reader.sendError(err) } @@ -288,7 +293,8 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { Name: collectionName, }, } - if !reader.shouldReadFunc(tmpCollectionInfo) { + dbInfo := reader.metaOp.GetDatabaseInfoForCollection(ctx, tmpCollectionInfo.ID) + if !reader.shouldReadFunc(&dbInfo, tmpCollectionInfo) { readerLog.Info("the collection is not in the watch list", zap.String("collection_name", collectionName), zap.String("partition_name", info.PartitionName)) return true } @@ -297,7 +303,7 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { zap.Int64("partition_id", info.PartitionID), zap.String("collection_name", collectionName), zap.Int64("collection_id", info.CollectionId)) - err := reader.channelManager.AddPartition(ctx, tmpCollectionInfo, info) + err := reader.channelManager.AddPartition(ctx, &dbInfo, tmpCollectionInfo, info) if err != nil { readerLog.Warn("fail to add partition", zap.String("collection_name", collectionName), zap.String("partition_name", info.PartitionName), zap.Error(err)) reader.sendError(err) diff --git a/core/reader/collection_reader_test.go b/core/reader/collection_reader_test.go index fe2caac0..f29ed37e 100644 --- a/core/reader/collection_reader_test.go +++ b/core/reader/collection_reader_test.go @@ -37,6 +37,7 @@ import ( api2 "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/config" "github.com/zilliztech/milvus-cdc/core/mocks" + "github.com/zilliztech/milvus-cdc/core/model" "github.com/zilliztech/milvus-cdc/core/pb" ) @@ -118,11 +119,11 @@ func TestCollectionReader(t *testing.T) { channelManager := mocks.NewChannelManager(t) // existed collection and partition - channelManager.EXPECT().StartReadCollection(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock err")).Once() - channelManager.EXPECT().AddPartition(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + channelManager.EXPECT().StartReadCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock err")).Once() + channelManager.EXPECT().AddPartition(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() channelManager.EXPECT().AddDroppedCollection(mock.Anything).Return().Once() - reader, err := NewCollectionReader("reader-1", channelManager, etcdOp, nil, func(ci *pb.CollectionInfo) bool { + reader, err := NewCollectionReader("reader-1", channelManager, etcdOp, nil, func(_ *model.DatabaseInfo, ci *pb.CollectionInfo) bool { return !strings.Contains(ci.Schema.Name, "test") }, config.ReaderConfig{ Retry: config.RetrySettings{ @@ -142,8 +143,8 @@ func TestCollectionReader(t *testing.T) { }() reader.StartRead(context.Background()) // put collection and partition - channelManager.EXPECT().StartReadCollection(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - channelManager.EXPECT().AddPartition(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + channelManager.EXPECT().StartReadCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + channelManager.EXPECT().AddPartition(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() { // filter collection diff --git a/core/reader/replicate_channel_manager.go b/core/reader/replicate_channel_manager.go index db118a98..dff0bf4e 100644 --- 
a/core/reader/replicate_channel_manager.go +++ b/core/reader/replicate_channel_manager.go @@ -169,7 +169,7 @@ func (r *replicateChannelManager) AddDroppedPartition(ids []int64) { log.Info("has removed dropped partitions", zap.Int64s("ids", ids)) } -func (r *replicateChannelManager) startReadCollectionForKafka(ctx context.Context, info *pb.CollectionInfo, sourceDBInfo model.DatabaseInfo) (*model.CollectionInfo, error) { +func (r *replicateChannelManager) startReadCollectionForKafka(ctx context.Context, info *pb.CollectionInfo, sourceDBInfo *model.DatabaseInfo) (*model.CollectionInfo, error) { r.collectionLock.RLock() _, ok := r.replicateCollections[info.ID] r.collectionLock.RUnlock() @@ -212,7 +212,7 @@ func (r *replicateChannelManager) startReadCollectionForKafka(ctx context.Contex return targetInfo, nil } -func (r *replicateChannelManager) startReadCollectionForMilvus(ctx context.Context, info *pb.CollectionInfo, sourceDBInfo model.DatabaseInfo) (*model.CollectionInfo, error) { +func (r *replicateChannelManager) startReadCollectionForMilvus(ctx context.Context, info *pb.CollectionInfo, sourceDBInfo *model.DatabaseInfo) (*model.CollectionInfo, error) { var err error retryErr := retry.Do(ctx, func() error { _, err = r.targetClient.GetCollectionInfo(ctx, info.Schema.GetName(), sourceDBInfo.Name) @@ -270,7 +270,7 @@ func (r *replicateChannelManager) startReadCollectionForMilvus(ctx context.Conte return targetInfo, nil } -func (r *replicateChannelManager) sendCreateCollectionvent(ctx context.Context, info *pb.CollectionInfo, sourceDBInfo model.DatabaseInfo) error { +func (r *replicateChannelManager) sendCreateCollectionvent(ctx context.Context, info *pb.CollectionInfo, sourceDBInfo *model.DatabaseInfo) error { select { case <-ctx.Done(): log.Warn("context is done in the start read collection") @@ -290,7 +290,7 @@ func (r *replicateChannelManager) sendCreateCollectionvent(ctx context.Context, return nil } -func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { +func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *model.DatabaseInfo, info *pb.CollectionInfo, seekPositions []*msgpb.MsgPosition) error { r.addCollectionLock.Lock() *r.addCollectionCnt++ r.addCollectionLock.Unlock() @@ -308,11 +308,10 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info var targetInfo *model.CollectionInfo var err error - sourceDBInfo := r.metaOp.GetDatabaseInfoForCollection(ctx, info.ID) if r.downstream == "milvus" { - targetInfo, err = r.startReadCollectionForMilvus(ctx, info, sourceDBInfo) + targetInfo, err = r.startReadCollectionForMilvus(ctx, info, db) } else if r.downstream == "kafka" { - targetInfo, err = r.startReadCollectionForKafka(ctx, info, sourceDBInfo) + targetInfo, err = r.startReadCollectionForKafka(ctx, info, db) } if err != nil { @@ -440,14 +439,12 @@ func ForeachChannel(sourcePChannels, targetPChannels []string, f func(sourcePCha return nil } -func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { +func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *model.DatabaseInfo, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { var handlers []*replicateChannelHandler collectionID := collectionInfo.ID partitionLog := log.With(zap.Int64("partition_id", partitionInfo.PartitionID), zap.Int64("collection_id", collectionID), 
zap.String("collection_name", collectionInfo.Schema.Name), zap.String("partition_name", partitionInfo.PartitionName)) - sourceDBInfo := r.metaOp.GetDatabaseInfoForCollection(ctx, collectionID) - - if sourceDBInfo.Dropped { + if dbInfo.Dropped { partitionLog.Info("the database has been dropped when add partition") return nil } @@ -514,7 +511,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionIn IsReplicate: true, MsgTimestamp: partitionInfo.PartitionCreatedTimestamp, }, - ReplicateParam: api.ReplicateParam{Database: sourceDBInfo.Name}, + ReplicateParam: api.ReplicateParam{Database: dbInfo.Name}, }: case <-ctx.Done(): partitionLog.Warn("context is done when adding partition") @@ -534,7 +531,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionIn IsReplicate: true, MsgTimestamp: msgTs, }, - ReplicateParam: api.ReplicateParam{Database: sourceDBInfo.Name}, + ReplicateParam: api.ReplicateParam{Database: dbInfo.Name}, }: r.droppedPartitions.Store(partitionInfo.PartitionID, struct{}{}) for _, handler := range handlers { diff --git a/core/reader/replicate_channel_manager_test.go b/core/reader/replicate_channel_manager_test.go index 782466f4..90a7ac6b 100644 --- a/core/reader/replicate_channel_manager_test.go +++ b/core/reader/replicate_channel_manager_test.go @@ -195,7 +195,7 @@ func TestStartReadCollection(t *testing.T) { t.Run("context cancel", func(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) cancelFunc() - err = manager.StartReadCollection(ctx, &pb.CollectionInfo{}, nil) + err = manager.StartReadCollection(ctx, &model.DatabaseInfo{}, &pb.CollectionInfo{}, nil) assert.Error(t, err) }) @@ -215,7 +215,7 @@ func TestStartReadCollection(t *testing.T) { realManager.startReadRetryOptions = []retry.Option{ retry.Attempts(1), } - err = manager.StartReadCollection(context.Background(), &pb.CollectionInfo{ + err = manager.StartReadCollection(context.Background(), &model.DatabaseInfo{}, &pb.CollectionInfo{ Schema: &schemapb.CollectionSchema{ Name: "test", }, @@ -291,7 +291,7 @@ func TestStartReadCollection(t *testing.T) { "_default": 31010, }, }, nil).Twice() - err := realManager.StartReadCollection(context.Background(), &pb.CollectionInfo{ + err := realManager.StartReadCollection(context.Background(), &model.DatabaseInfo{}, &pb.CollectionInfo{ ID: 31001, Schema: &schemapb.CollectionSchema{ Name: "test", @@ -312,7 +312,7 @@ func TestStartReadCollection(t *testing.T) { realManager.retryOptions = []retry.Option{ retry.Attempts(1), } - err := realManager.AddPartition(context.Background(), &pb.CollectionInfo{ + err := realManager.AddPartition(context.Background(), &model.DatabaseInfo{}, &pb.CollectionInfo{ ID: 41, Schema: &schemapb.CollectionSchema{ Name: "test", @@ -323,7 +323,7 @@ func TestStartReadCollection(t *testing.T) { // add partition { - err := realManager.AddPartition(context.Background(), &pb.CollectionInfo{ + err := realManager.AddPartition(context.Background(), &model.DatabaseInfo{}, &pb.CollectionInfo{ ID: 31001, Schema: &schemapb.CollectionSchema{ Name: "test", diff --git a/core/util/msgpack.go b/core/util/msgpack.go index 53306587..5aa1a33e 100644 --- a/core/util/msgpack.go +++ b/core/util/msgpack.go @@ -23,6 +23,7 @@ import ( "sync" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/requestutil" ) var SuffixSnapshotTombstone = []byte{0xE2, 0x9B, 0xBC} // base64 value: "4pu8" @@ -53,8 +54,17 @@ func GetCollectionNameFromMsgPack(msgPack *msgstream.MsgPack) string { 
return "" } firstMsg := msgPack.Msgs[0] - collectionName, _ := GetCollectionNameFromRequest(firstMsg) - return collectionName + collectionName, _ := requestutil.GetCollectionNameFromRequest(firstMsg) + return collectionName.(string) +} + +func GetDatabaseNameFromMsgPack(msgPack *msgstream.MsgPack) string { + if len(msgPack.Msgs) == 0 { + return "" + } + firstMsg := msgPack.Msgs[0] + dbName, _ := requestutil.GetDbNameFromRequest(firstMsg) + return dbName.(string) } func GetCollectionIDFromMsgPack(msgPack *msgstream.MsgPack) int64 { diff --git a/server/cdc_impl.go b/server/cdc_impl.go index d1250dcf..151155d8 100644 --- a/server/cdc_impl.go +++ b/server/cdc_impl.go @@ -44,6 +44,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/config" "github.com/zilliztech/milvus-cdc/core/log" + coremodel "github.com/zilliztech/milvus-cdc/core/model" "github.com/zilliztech/milvus-cdc/core/pb" cdcreader "github.com/zilliztech/milvus-cdc/core/reader" "github.com/zilliztech/milvus-cdc/core/util" @@ -304,6 +305,7 @@ func (e *MetaCDC) Create(req *request.CreateRequest) (resp *request.CreateRespon TaskID: e.getUUID(), MilvusConnectParam: req.MilvusConnectParam, KafkaConnectParam: req.KafkaConnectParam, + DatabaseInfo: req.DatabaseInfo, CollectionInfos: req.CollectionInfos, RPCRequestChannelInfo: req.RPCChannelInfo, ExcludeCollections: excludeCollectionNames, @@ -952,7 +954,9 @@ func replicateMetric(info *meta.TaskInfo, channelName string, msgPack *msgstream func (e *MetaCDC) getChannelReader(info *meta.TaskInfo, replicateEntity *ReplicateEntity, channelName, channelPosition string) (api.Reader, error) { taskLog := log.With(zap.String("task_id", info.TaskID)) collectionName := info.CollectionNames()[0] + databaseName := info.DatabaseInfo.Name isAnyCollection := collectionName == cdcreader.AllCollection + isAnyDatabase := databaseName == "" // isTmpCollection := collectionName == model.TmpCollectionName dataHandleFunc := func(funcCtx context.Context, pack *msgstream.MsgPack) bool { @@ -967,8 +971,11 @@ func (e *MetaCDC) getChannelReader(info *meta.TaskInfo, replicateEntity *Replica Set(float64(msgTime)) msgCollectionName := util.GetCollectionNameFromMsgPack(pack) + msgDatabaseName := util.GetDatabaseNameFromMsgPack(pack) // TODO it should be changed if replicate the user and role info or multi collection - if !isAnyCollection && msgCollectionName != collectionName { + // TODO how to handle it when there are "*" and "foo" collection names in the task list + if (!isAnyCollection && msgCollectionName != collectionName) || + (!isAnyDatabase && msgDatabaseName != databaseName) { // skip the message if the collection name is not equal to the task collection name return true } @@ -1253,15 +1260,22 @@ func (e *MetaCDC) Maintenance(req *request.MaintenanceRequest) (*request.Mainten func GetShouldReadFunc(taskInfo *meta.TaskInfo) cdcreader.ShouldReadFunc { isAll := taskInfo.CollectionInfos[0].Name == cdcreader.AllCollection - return func(collectionInfo *pb.CollectionInfo) bool { + return func(databaseInfo *coremodel.DatabaseInfo, collectionInfo *pb.CollectionInfo) bool { currentCollectionName := collectionInfo.Schema.Name + if databaseInfo.Dropped { + log.Info("database is dropped", zap.String("database", databaseInfo.Name), zap.String("collection", currentCollectionName)) + return false + } + notStarContains := !isAll && lo.ContainsBy(taskInfo.CollectionInfos, func(taskCollectionInfo model.CollectionInfo) bool { return taskCollectionInfo.Name == currentCollectionName }) 
starContains := isAll && !lo.ContainsBy(taskInfo.ExcludeCollections, func(s string) bool { return s == currentCollectionName }) + dbMatch := taskInfo.DatabaseInfo.Name == "" || + taskInfo.DatabaseInfo.Name == databaseInfo.Name - return notStarContains || starContains + return (notStarContains || starContains) && dbMatch } } diff --git a/server/cdc_impl_test.go b/server/cdc_impl_test.go index df9a4a69..87854348 100644 --- a/server/cdc_impl_test.go +++ b/server/cdc_impl_test.go @@ -40,6 +40,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/config" coremocks "github.com/zilliztech/milvus-cdc/core/mocks" + coremodel "github.com/zilliztech/milvus-cdc/core/model" "github.com/zilliztech/milvus-cdc/core/pb" cdcreader "github.com/zilliztech/milvus-cdc/core/reader" "github.com/zilliztech/milvus-cdc/core/util" @@ -760,16 +761,31 @@ func TestShouldReadCollection(t *testing.T) { }, ExcludeCollections: []string{"foo"}, }) - assert.True(t, f(&pb.CollectionInfo{ - Schema: &schemapb.CollectionSchema{ - Name: "hoo", - }, - })) - assert.False(t, f(&pb.CollectionInfo{ - Schema: &schemapb.CollectionSchema{ - Name: "foo", + assert.True(t, f( + &coremodel.DatabaseInfo{}, + &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "hoo", + }, + })) + + assert.False(t, f( + &coremodel.DatabaseInfo{}, + &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "foo", + }, + })) + + assert.False(t, f( + &coremodel.DatabaseInfo{ + Dropped: true, }, - })) + &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "hoo", + }, + })) }) t.Run("some collection", func(t *testing.T) { @@ -784,21 +800,27 @@ func TestShouldReadCollection(t *testing.T) { }, ExcludeCollections: []string{"foo"}, }) - assert.True(t, f(&pb.CollectionInfo{ - Schema: &schemapb.CollectionSchema{ - Name: "a", - }, - })) - assert.False(t, f(&pb.CollectionInfo{ - Schema: &schemapb.CollectionSchema{ - Name: "c", - }, - })) - assert.False(t, f(&pb.CollectionInfo{ - Schema: &schemapb.CollectionSchema{ - Name: "foo", - }, - })) + assert.True(t, f( + &coremodel.DatabaseInfo{}, + &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "a", + }, + })) + assert.False(t, f( + &coremodel.DatabaseInfo{}, + &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "c", + }, + })) + assert.False(t, f( + &coremodel.DatabaseInfo{}, + &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "foo", + }, + })) }) } diff --git a/server/model/common.go b/server/model/common.go index cba3689f..6a9129cf 100644 --- a/server/model/common.go +++ b/server/model/common.go @@ -56,6 +56,10 @@ type MilvusConnectParam struct { ConnectTimeout int `json:"connect_timeout" mapstructure:"connect_timeout"` } +type DatabaseInfo struct { + Name string `json:"name" mapstructure:"name"` +} + type CollectionInfo struct { Name string `json:"name" mapstructure:"name"` Positions map[string]string `json:"positions" mapstructure:"positions"` // the key is the vchannel diff --git a/server/model/meta/task.go b/server/model/meta/task.go index d9ff285b..ea5d84ab 100644 --- a/server/model/meta/task.go +++ b/server/model/meta/task.go @@ -64,6 +64,7 @@ type TaskInfo struct { KafkaConnectParam model.KafkaConnectParam WriterCacheConfig model.BufferConfig CollectionInfos []model.CollectionInfo + DatabaseInfo model.DatabaseInfo RPCRequestChannelInfo model.ChannelInfo ExcludeCollections []string // it's used for the `*` collection name State TaskState diff --git a/server/model/request/create.go b/server/model/request/create.go index 972254db..0683e75a 
100644 --- a/server/model/request/create.go +++ b/server/model/request/create.go @@ -25,6 +25,7 @@ type CreateRequest struct { KafkaConnectParam model.KafkaConnectParam `json:"kafka_connect_param,omitempty" mapstructure:"kafka_connect_param,omitempty"` MilvusConnectParam model.MilvusConnectParam `json:"milvus_connect_param" mapstructure:"milvus_connect_param"` CollectionInfos []model.CollectionInfo `json:"collection_infos" mapstructure:"collection_infos"` + DatabaseInfo model.DatabaseInfo `json:"database_info" mapstructure:"database_info"` RPCChannelInfo model.ChannelInfo `json:"rpc_channel_info" mapstructure:"rpc_channel_info"` BufferConfig model.BufferConfig `json:"buffer_config" mapstructure:"buffer_config"` // Deprecated diff --git a/tests/testcases/test_cdc_database.py b/tests/testcases/test_cdc_database.py new file mode 100644 index 00000000..d38c492d --- /dev/null +++ b/tests/testcases/test_cdc_database.py @@ -0,0 +1,200 @@ +import random + +import pytest +import time +import numpy as np +from datetime import datetime +from utils.util_log import test_log as log +from api.milvus_cdc import MilvusCdcClient +from pymilvus import ( + connections, list_collections, + Collection, Partition, db, + utility, +) +from pymilvus.client.types import LoadState +from pymilvus.orm.role import Role +from base.checker import default_schema, list_partitions +from base.checker import ( + InsertEntitiesPartitionChecker, + InsertEntitiesCollectionChecker +) +from base.client_base import TestBase + +prefix = "cdc_create_task_" +# client = MilvusCdcClient('http://localhost:8444') + + +class TestCDCSyncRequest(TestBase): + """ Test Milvus CDC end to end """ + + def connect_downstream(self, host, port, token="root:Milvus"): + connections.connect(host=host, port=port, token=token) + + def test_cdc_sync_default_database_request(self, upstream_host, upstream_port, downstream_host, downstream_port): + """ + target: test cdc without not match database + method: create collection/insert/load/flush in upstream + expected: create successfully + """ + connections.connect(host=upstream_host, port=upstream_port) + col_list = [] + for i in range(10): + time.sleep(0.1) + collection_name = prefix + "not_match_database_" + datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f') + col_list.append(collection_name) + # create collections in upstream + for col_name in col_list: + c = Collection(name=col_name, schema=default_schema) + log.info(f"create collection {col_name} in upstream") + # insert data to upstream + nb = 300 + epoch = 10 + for e in range(epoch): + data = [ + [i for i in range(nb)], + [np.float32(i) for i in range(nb)], + [str(i) for i in range(nb)], + [[random.random() for _ in range(128)] for _ in range(nb)] + ] + c.insert(data) + c.flush() + index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + c.create_index("float_vector", index_params) + c.load() + # check collections in upstream + assert set(col_list).issubset(set(list_collections())) + # check collections in downstream + connections.disconnect("default") + self.connect_downstream(downstream_host, downstream_port) + idx = 0 + log.info(f"all collections in downstream {list_collections()}") + while idx < 10: + downstream_col_list = list_collections() + if len(downstream_col_list) != 0: + log.info(f"all collections in downstream {downstream_col_list}") + time.sleep(3) + idx += 1 + assert len(list_collections()) == 0, f"collections in downstream {list_collections()}" + + def test_cdc_sync_not_match_database_request(self, 
upstream_host, upstream_port, downstream_host, downstream_port): + """ + target: test cdc without not match database + method: create collection/insert/load/flush in upstream + expected: create successfully + """ + connections.connect(host=upstream_host, port=upstream_port) + db.create_database("hoo") + db.using_database(db_name="hoo") + col_list = [] + for i in range(10): + time.sleep(0.1) + collection_name = prefix + "not_match_database_" + datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f') + col_list.append(collection_name) + # create collections in upstream + for col_name in col_list: + c = Collection(name=col_name, schema=default_schema) + log.info(f"create collection {col_name} in upstream") + # insert data to upstream + nb = 300 + epoch = 10 + for e in range(epoch): + data = [ + [i for i in range(nb)], + [np.float32(i) for i in range(nb)], + [str(i) for i in range(nb)], + [[random.random() for _ in range(128)] for _ in range(nb)] + ] + c.insert(data) + c.flush() + index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + c.create_index("float_vector", index_params) + c.load() + # check collections in upstream + assert set(col_list).issubset(set(list_collections())) + # check collections in downstream + connections.disconnect("default") + self.connect_downstream(downstream_host, downstream_port) + idx = 0 + db_list = db.list_database() + log.info(f"all collections in downstream {db_list}") + while idx < 10: + db_list = db.list_database() + if len(db_list) != 1: + log.info(f"all collections in downstream {db_list}") + time.sleep(3) + idx += 1 + assert len(db_list) == 1, f"collections in downstream {db.list_database()}" + + def test_cdc_sync_match_database_request(self, upstream_host, upstream_port, downstream_host, downstream_port): + """ + target: test cdc with match database + method: create collection/insert/load/flush in upstream + expected: create successfully + """ + connections.connect(host=upstream_host, port=upstream_port) + db.create_database("foo") + db.using_database(db_name="foo") + col_list = [] + col_name = prefix + "match_database_" + datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f') + # create collections in upstream + c = Collection(name=col_name, schema=default_schema) + log.info(f"create collection {col_name} in upstream") + # insert data to upstream + nb = 300 + epoch = 10 + for e in range(epoch): + data = [ + [i for i in range(nb)], + [np.float32(i) for i in range(nb)], + [str(i) for i in range(nb)], + [[random.random() for _ in range(128)] for _ in range(nb)] + ] + c.insert(data) + c.flush() + index_params = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + c.create_index("float_vector", index_params) + c.load() + # check collections in upstream + assert set(col_list).issubset(set(list_collections())) + # check collections in downstream + connections.disconnect("default") + self.connect_downstream(downstream_host, downstream_port) + + db_list = db.list_database() + log.info(f"all collections in downstream {db_list}") + timeout = 20 + t0 = time.time() + db_name = "foo" + + while time.time() - t0 < timeout: + if db_name in db.list_database(): + log.info(f"database synced in downstream successfully cost time: {time.time() - t0:.2f}s") + break + time.sleep(2) + if time.time() - t0 > timeout: + log.info(f"database synced in downstream failed with timeout: {time.time() - t0:.2f}s") + log.info(f"database in downstream {db.list_database()}") + assert db_name in db.list_database() + + 
db.using_database(db_name=db_name)
+        c_downstream = Collection(name=col_name)
+        timeout = 60
+        t0 = time.time()
+        log.info(f"all collections in downstream {list_collections()}")
+        while True:
+            if time.time() - t0 > timeout:
+                log.info(f"collection sync in downstream timed out after {time.time() - t0:.2f}s")
+                break
+            # check how many entities have been synced to the downstream collection
+            if c_downstream.num_entities != nb*epoch:
+                log.info(f"sync progress:{c_downstream.num_entities / (nb*epoch) * 100:.2f}%")
+            # the downstream collection has caught up with the upstream
+            if c_downstream.num_entities == nb*epoch:
+                log.info(f"collection synced in downstream successfully cost time: {time.time() - t0:.2f}s")
+                break
+            time.sleep(1)
+            try:
+                c_downstream.flush(timeout=5)
+            except Exception as e:
+                log.info(f"flush err: {str(e)}")
+        assert c_downstream.num_entities == nb*epoch
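
Note for reviewers: the replication rule this patch introduces can be restated as: a collection is read only when its database is not dropped, the task's database_info.name is empty or matches the collection's database, and the collection name passes the existing include/exclude ("*" plus ExcludeCollections) rules. The standalone Go sketch below restates that rule with simplified stand-in types; DatabaseInfo here is a trimmed copy of the real model type, while TaskFilter, shouldRead, and contains are illustrative names only and are not part of the project's API (the actual logic lives in GetShouldReadFunc in server/cdc_impl.go).

package main

import "fmt"

// DatabaseInfo is a simplified stand-in for core/model.DatabaseInfo.
type DatabaseInfo struct {
	Name    string
	Dropped bool
}

// TaskFilter is an illustrative stand-in for the task's replication scope.
type TaskFilter struct {
	DatabaseName       string   // empty means "any database"
	CollectionNames    []string // may be []string{"*"}
	ExcludeCollections []string // only consulted when "*" is used
}

func contains(list []string, s string) bool {
	for _, v := range list {
		if v == s {
			return true
		}
	}
	return false
}

// shouldRead mirrors the decision made after this patch: skip dropped
// databases, require a database match when one is configured, then apply
// the existing collection include/exclude rules.
func shouldRead(f TaskFilter, db DatabaseInfo, collection string) bool {
	if db.Dropped {
		return false
	}
	if f.DatabaseName != "" && f.DatabaseName != db.Name {
		return false
	}
	isAll := len(f.CollectionNames) > 0 && f.CollectionNames[0] == "*"
	if isAll {
		return !contains(f.ExcludeCollections, collection)
	}
	return contains(f.CollectionNames, collection)
}

func main() {
	filter := TaskFilter{DatabaseName: "foo", CollectionNames: []string{"*"}}
	fmt.Println(shouldRead(filter, DatabaseInfo{Name: "foo"}, "c1"))                // true: database matches
	fmt.Println(shouldRead(filter, DatabaseInfo{Name: "hoo"}, "c1"))                // false: other database
	fmt.Println(shouldRead(filter, DatabaseInfo{Name: "foo", Dropped: true}, "c1")) // false: database dropped
}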