diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c537ae4..798d574 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -93,6 +93,12 @@ jobs: "enable_tls": false, "connect_timeout": 120 }, + "extra_info": { + "enable_user_role": true + }, + "database_info": { + "name": "*" + }, "collection_infos": [ { "name": "*" @@ -253,6 +259,12 @@ jobs: "token": "root:Milvus", "connect_timeout": 120 }, + "extra_info": { + "enable_user_role": true + }, + "database_info": { + "name": "*" + }, "collection_infos": [ { "name": "*" @@ -390,6 +402,12 @@ jobs: "connect_timeout": 120, "channel_num": 16 }, + "extra_info": { + "enable_user_role": true + }, + "database_info": { + "name": "*" + }, "collection_infos": [ { "name": "*" @@ -527,6 +545,12 @@ jobs: "connect_timeout": 120, "channel_num": 8 }, + "extra_info": { + "enable_user_role": true + }, + "database_info": { + "name": "*" + }, "collection_infos": [ { "name": "*" @@ -662,6 +686,9 @@ jobs: "token": "root:Milvus", "connect_timeout": 120 }, + "extra_info": { + "enable_user_role": true + }, "collection_infos": [ { "name": "*" diff --git a/core/reader/collection_reader.go b/core/reader/collection_reader.go index 636c146..425be3f 100644 --- a/core/reader/collection_reader.go +++ b/core/reader/collection_reader.go @@ -41,7 +41,9 @@ import ( ) const ( - AllCollection = "*" + AllCollection = "*" + AllDatabase = "*" + DefaultDatabase = "default" ) type CollectionInfo struct { diff --git a/core/util/msgpack.go b/core/util/msgpack.go index 5aa1a33..dd55b97 100644 --- a/core/util/msgpack.go +++ b/core/util/msgpack.go @@ -22,6 +22,7 @@ import ( "bytes" "sync" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/requestutil" ) @@ -75,3 +76,17 @@ func GetCollectionIDFromMsgPack(msgPack *msgstream.MsgPack) int64 { collectionID, _ := GetCollectionIDFromRequest(firstMsg) return collectionID } + +func IsUserRoleMessage(msgPack *msgstream.MsgPack) bool { + if len(msgPack.Msgs) == 0 { + return false + } + msgType := msgPack.Msgs[0].Type() + return msgType == commonpb.MsgType_CreateCredential || + msgType == commonpb.MsgType_DeleteCredential || + msgType == commonpb.MsgType_UpdateCredential || + msgType == commonpb.MsgType_CreateRole || + msgType == commonpb.MsgType_DropRole || + msgType == commonpb.MsgType_OperateUserRole || + msgType == commonpb.MsgType_OperatePrivilege +} diff --git a/server/cdc_impl.go b/server/cdc_impl.go index 151155d..f0a84f5 100644 --- a/server/cdc_impl.go +++ b/server/cdc_impl.go @@ -84,6 +84,7 @@ type MetaCDC struct { sync.RWMutex data map[string][]string excludeData map[string][]string + extraInfos map[string]model.ExtraInfo } cdcTasks struct { sync.RWMutex @@ -142,6 +143,7 @@ func NewMetaCDC(serverConfig *CDCServerConfig) *MetaCDC { cdc.collectionNames.data = make(map[string][]string) cdc.collectionNames.excludeData = make(map[string][]string) + cdc.collectionNames.extraInfos = make(map[string]model.ExtraInfo) cdc.cdcTasks.data = make(map[string]*meta.TaskInfo) cdc.replicateEntityMap.data = make(map[string]*ReplicateEntity) return cdc @@ -176,6 +178,7 @@ func (e *MetaCDC) ReloadTask() { }) e.collectionNames.data[uKey] = append(e.collectionNames.data[uKey], newCollectionNames...) e.collectionNames.excludeData[uKey] = append(e.collectionNames.excludeData[uKey], taskInfo.ExcludeCollections...) + e.collectionNames.excludeData[uKey] = lo.Uniq(e.collectionNames.excludeData[uKey]) e.cdcTasks.Lock() e.cdcTasks.data[taskInfo.TaskID] = taskInfo e.cdcTasks.Unlock() @@ -231,6 +234,99 @@ func getTaskUniqueIDFromReq(req *request.CreateRequest) string { panic("fail to get the task unique id") } +func getDatabaseName(i any) string { + switch r := i.(type) { + case *meta.TaskInfo: + if r.DatabaseInfo.Name != "" { + return r.DatabaseInfo.Name + } + return cdcreader.DefaultDatabase + case *request.CreateRequest: + if r.DatabaseInfo.Name != "" { + return r.DatabaseInfo.Name + } + return cdcreader.DefaultDatabase + default: + panic("invalid type") + } +} + +func getFullCollectionName(collectionName string, databaseName string) string { + return fmt.Sprintf("%s.%s", databaseName, collectionName) +} + +func getCollectionNameFromFull(fullName string) (string, string) { + names := strings.Split(fullName, ".") + if len(names) != 2 { + panic("invalid full collection name") + } + return names[0], names[1] +} + +func matchCollectionName(sampleCollection, targetCollection string) (bool, bool) { + db1, collection1 := getCollectionNameFromFull(sampleCollection) + db2, collection2 := getCollectionNameFromFull(targetCollection) + return (db1 == db2 || db1 == cdcreader.AllDatabase) && + (collection1 == collection2 || collection1 == cdcreader.AllCollection), + db1 == cdcreader.AllDatabase || collection1 == cdcreader.AllCollection +} + +func (e *MetaCDC) checkDuplicateCollection(uKey string, newCollectionNames []string, extraInfo model.ExtraInfo) ([]string, error) { + e.collectionNames.Lock() + defer e.collectionNames.Unlock() + existExtraInfo := e.collectionNames.extraInfos[uKey] + if existExtraInfo.EnableUserRole && extraInfo.EnableUserRole { + return nil, servererror.NewClientError("the enable user role param is duplicate") + } + if names, ok := e.collectionNames.data[uKey]; ok { + var duplicateCollections []string + containsAny := false + for _, name := range names { + d, c := getCollectionNameFromFull(name) + if d == cdcreader.AllDatabase || c == cdcreader.AllCollection { + containsAny = true + } + } + for _, newCollectionName := range newCollectionNames { + if lo.Contains(names, newCollectionName) { + duplicateCollections = append(duplicateCollections, newCollectionName) + continue + } + nd, nc := getCollectionNameFromFull(newCollectionName) + if nd == cdcreader.AllDatabase && nc == cdcreader.AllCollection { + continue + } + if containsAny && !lo.Contains(e.collectionNames.excludeData[uKey], newCollectionName) { + duplicateCollections = append(duplicateCollections, newCollectionName) + continue + } + } + if len(duplicateCollections) > 0 { + log.Info("duplicate collections", + zap.Strings("request_collections", newCollectionNames), + zap.Strings("exist_collections", names), + zap.Strings("exclude_collections", e.collectionNames.excludeData[uKey]), + zap.Strings("duplicate_collections", duplicateCollections)) + return nil, servererror.NewClientError(fmt.Sprintf("the collection name is duplicate with existing task, %v", duplicateCollections)) + } + } + // release lock early to accept other requests + var excludeCollectionNames []string + for _, newCollectionName := range newCollectionNames { + for _, existCollectionName := range e.collectionNames.data[uKey] { + if match, _ := matchCollectionName(newCollectionName, existCollectionName); match { + excludeCollectionNames = append(excludeCollectionNames, existCollectionName) + } + } + } + e.collectionNames.excludeData[uKey] = append(e.collectionNames.excludeData[uKey], excludeCollectionNames...) + e.collectionNames.data[uKey] = append(e.collectionNames.data[uKey], newCollectionNames...) + e.collectionNames.extraInfos[uKey] = model.ExtraInfo{ + EnableUserRole: existExtraInfo.EnableUserRole || extraInfo.EnableUserRole, + } + return excludeCollectionNames, nil +} + func (e *MetaCDC) Create(req *request.CreateRequest) (resp *request.CreateResponse, err error) { defer func() { log.Info("create request done") @@ -242,53 +338,19 @@ func (e *MetaCDC) Create(req *request.CreateRequest) (resp *request.CreateRespon return nil, err } uKey := getTaskUniqueIDFromReq(req) + databaseName := getDatabaseName(req) newCollectionNames := lo.Map(req.CollectionInfos, func(t model.CollectionInfo, _ int) string { - return t.Name + return getFullCollectionName(t.Name, databaseName) }) - e.collectionNames.Lock() - if names, ok := e.collectionNames.data[uKey]; ok { - existAll := lo.Contains(names, cdcreader.AllCollection) - duplicateCollections := lo.Filter(req.CollectionInfos, func(info model.CollectionInfo, _ int) bool { - return (!existAll && lo.Contains(names, info.Name)) || (existAll && info.Name == cdcreader.AllCollection) - }) - if len(duplicateCollections) > 0 { - e.collectionNames.Unlock() - return nil, servererror.NewClientError(fmt.Sprintf("some collections are duplicate with existing tasks, %v", lo.Map(duplicateCollections, func(t model.CollectionInfo, i int) string { - return t.Name - }))) - } - if existAll { - excludeCollectionNames := lo.Filter(e.collectionNames.excludeData[uKey], func(s string, _ int) bool { - return !lo.Contains(names, s) - }) - duplicateCollections = lo.Filter(req.CollectionInfos, func(info model.CollectionInfo, _ int) bool { - return !lo.Contains(excludeCollectionNames, info.Name) - }) - if len(duplicateCollections) > 0 { - e.collectionNames.Unlock() - return nil, servererror.NewClientError(fmt.Sprintf("some collections are duplicate with existing tasks, check the `*` collection task and other tasks, %v", lo.Map(duplicateCollections, func(t model.CollectionInfo, i int) string { - return t.Name - }))) - } - } - } - // release lock early to accept other requests - var excludeCollectionNames []string - if newCollectionNames[0] == cdcreader.AllCollection { - existCollectionNames := e.collectionNames.data[uKey] - excludeCollectionNames = make([]string, len(existCollectionNames)) - copy(excludeCollectionNames, existCollectionNames) - e.collectionNames.excludeData[uKey] = excludeCollectionNames + excludeCollectionNames, err := e.checkDuplicateCollection(uKey, newCollectionNames, req.ExtraInfo) + if err != nil { + return nil, err } - e.collectionNames.data[uKey] = append(e.collectionNames.data[uKey], newCollectionNames...) - e.collectionNames.Unlock() revertCollectionNames := func() { e.collectionNames.Lock() defer e.collectionNames.Unlock() - if newCollectionNames[0] == cdcreader.AllCollection { - e.collectionNames.excludeData[uKey] = []string{} - } + e.collectionNames.excludeData[uKey] = lo.Without(e.collectionNames.excludeData[uKey], excludeCollectionNames...) e.collectionNames.data[uKey] = lo.Without(e.collectionNames.data[uKey], newCollectionNames...) } @@ -308,6 +370,7 @@ func (e *MetaCDC) Create(req *request.CreateRequest) (resp *request.CreateRespon DatabaseInfo: req.DatabaseInfo, CollectionInfos: req.CollectionInfos, RPCRequestChannelInfo: req.RPCChannelInfo, + ExtraInfo: req.ExtraInfo, ExcludeCollections: excludeCollectionNames, WriterCacheConfig: req.BufferConfig, State: meta.TaskStateInitial, @@ -954,9 +1017,9 @@ func replicateMetric(info *meta.TaskInfo, channelName string, msgPack *msgstream func (e *MetaCDC) getChannelReader(info *meta.TaskInfo, replicateEntity *ReplicateEntity, channelName, channelPosition string) (api.Reader, error) { taskLog := log.With(zap.String("task_id", info.TaskID)) collectionName := info.CollectionNames()[0] - databaseName := info.DatabaseInfo.Name + databaseName := getDatabaseName(info) isAnyCollection := collectionName == cdcreader.AllCollection - isAnyDatabase := databaseName == "" + isAnyDatabase := databaseName == cdcreader.AllDatabase // isTmpCollection := collectionName == model.TmpCollectionName dataHandleFunc := func(funcCtx context.Context, pack *msgstream.MsgPack) bool { @@ -974,7 +1037,15 @@ func (e *MetaCDC) getChannelReader(info *meta.TaskInfo, replicateEntity *Replica msgDatabaseName := util.GetDatabaseNameFromMsgPack(pack) // TODO it should be changed if replicate the user and role info or multi collection // TODO how to handle it when there are "*" and "foo" collection names in the task list - if (!isAnyCollection && msgCollectionName != collectionName) || + if msgCollectionName == "" && msgDatabaseName == "" { + extraSkip := true + if info.ExtraInfo.EnableUserRole && util.IsUserRoleMessage(pack) { + extraSkip = false + } + if extraSkip { + return true + } + } else if (!isAnyCollection && msgCollectionName != collectionName) || (!isAnyDatabase && msgDatabaseName != databaseName) { // skip the message if the collection name is not equal to the task collection name return true @@ -1101,9 +1172,7 @@ func (e *MetaCDC) delete(taskID string) error { uKey = milvusURI + kafkaAddress collectionNames := info.CollectionNames() e.collectionNames.Lock() - if collectionNames[0] == cdcreader.AllCollection { - e.collectionNames.excludeData[uKey] = []string{} - } + e.collectionNames.excludeData[uKey] = lo.Without(e.collectionNames.excludeData[uKey], info.ExcludeCollections...) e.collectionNames.data[uKey] = lo.Without(e.collectionNames.data[uKey], collectionNames...) e.collectionNames.Unlock() @@ -1259,7 +1328,9 @@ func (e *MetaCDC) Maintenance(req *request.MaintenanceRequest) (*request.Mainten } func GetShouldReadFunc(taskInfo *meta.TaskInfo) cdcreader.ShouldReadFunc { - isAll := taskInfo.CollectionInfos[0].Name == cdcreader.AllCollection + isAllCollection := taskInfo.CollectionInfos[0].Name == cdcreader.AllCollection + databaseName := getDatabaseName(taskInfo) + isAllDataBase := databaseName == cdcreader.AllDatabase return func(databaseInfo *coremodel.DatabaseInfo, collectionInfo *pb.CollectionInfo) bool { currentCollectionName := collectionInfo.Schema.Name if databaseInfo.Dropped { @@ -1267,13 +1338,13 @@ func GetShouldReadFunc(taskInfo *meta.TaskInfo) cdcreader.ShouldReadFunc { return false } - notStarContains := !isAll && lo.ContainsBy(taskInfo.CollectionInfos, func(taskCollectionInfo model.CollectionInfo) bool { + notStarContains := !isAllCollection && lo.ContainsBy(taskInfo.CollectionInfos, func(taskCollectionInfo model.CollectionInfo) bool { return taskCollectionInfo.Name == currentCollectionName }) - starContains := isAll && !lo.ContainsBy(taskInfo.ExcludeCollections, func(s string) bool { + starContains := isAllCollection && !lo.ContainsBy(taskInfo.ExcludeCollections, func(s string) bool { return s == currentCollectionName }) - dbMatch := taskInfo.DatabaseInfo.Name == "" || + dbMatch := isAllDataBase || taskInfo.DatabaseInfo.Name == databaseInfo.Name return (notStarContains || starContains) && dbMatch diff --git a/server/cdc_impl_test.go b/server/cdc_impl_test.go index 8785434..29b3124 100644 --- a/server/cdc_impl_test.go +++ b/server/cdc_impl_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" clientv3 "go.etcd.io/etcd/client/v3" @@ -699,6 +700,7 @@ func initMetaCDCMap(cdc *MetaCDC) { cdc.collectionNames.Lock() cdc.collectionNames.data = map[string][]string{} cdc.collectionNames.excludeData = map[string][]string{} + cdc.collectionNames.extraInfos = map[string]model.ExtraInfo{} cdc.collectionNames.Unlock() cdc.cdcTasks.Lock() @@ -1386,3 +1388,154 @@ func TestPauseTask(t *testing.T) { assert.True(t, isQuit.Load()) }) } + +func TestCheckDuplicateCollection(t *testing.T) { + t.Run("check enable user role", func(t *testing.T) { + metaCDC := &MetaCDC{} + initMetaCDCMap(metaCDC) + _, err := metaCDC.checkDuplicateCollection("foo", []string{}, model.ExtraInfo{ + EnableUserRole: true, + }) + assert.NoError(t, err) + + _, err = metaCDC.checkDuplicateCollection("foo", []string{}, model.ExtraInfo{ + EnableUserRole: true, + }) + assert.Error(t, err) + + _, err = metaCDC.checkDuplicateCollection("hoo", []string{}, model.ExtraInfo{ + EnableUserRole: true, + }) + assert.NoError(t, err) + }) + + t.Run("default db", func(t *testing.T) { + metaCDC := &MetaCDC{} + initMetaCDCMap(metaCDC) + + excludeCollections, err := metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "default"), + getFullCollectionName("hoo", "default"), + }, model.ExtraInfo{ + EnableUserRole: true, + }) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 0) + + excludeCollections, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("*", "default"), + }, model.ExtraInfo{}) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 2) + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("doo", "default"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + metaCDC.collectionNames.Lock() + metaCDC.collectionNames.data["foo"] = lo.Without(metaCDC.collectionNames.data["foo"], getFullCollectionName("foo", "default")) + metaCDC.collectionNames.Unlock() + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("hoo", "default"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + excludeCollections, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "default"), + }, model.ExtraInfo{}) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 0) + }) + + t.Run("more dbs", func(t *testing.T) { + metaCDC := &MetaCDC{} + initMetaCDCMap(metaCDC) + + excludeCollections, err := metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "db1"), + getFullCollectionName("foo", "db2"), + }, model.ExtraInfo{ + EnableUserRole: true, + }) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 0) + + excludeCollections, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "*"), + }, model.ExtraInfo{}) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 2) + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "db3"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + metaCDC.collectionNames.Lock() + metaCDC.collectionNames.data["foo"] = lo.Without(metaCDC.collectionNames.data["foo"], getFullCollectionName("foo", "db1")) + metaCDC.collectionNames.Unlock() + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "db2"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + excludeCollections, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "db1"), + }, model.ExtraInfo{}) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 0) + }) + + t.Run("collection and db mix", func(t *testing.T) { + metaCDC := &MetaCDC{} + initMetaCDCMap(metaCDC) + + excludeCollections, err := metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "*"), + getFullCollectionName("hoo", "db2"), + }, model.ExtraInfo{ + EnableUserRole: true, + }) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 0) + + excludeCollections, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("*", "*"), + }, model.ExtraInfo{}) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 2) + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "db3"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + metaCDC.collectionNames.Lock() + metaCDC.collectionNames.data["foo"] = lo.Without(metaCDC.collectionNames.data["foo"], getFullCollectionName("foo", "*")) + metaCDC.collectionNames.Unlock() + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "db2"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("hoo", "*"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + _, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("hoo", "db2"), + }, model.ExtraInfo{}) + assert.Error(t, err) + + excludeCollections, err = metaCDC.checkDuplicateCollection("foo", []string{ + getFullCollectionName("foo", "*"), + }, model.ExtraInfo{}) + assert.NoError(t, err) + assert.Len(t, excludeCollections, 0) + }) +} diff --git a/server/model/common.go b/server/model/common.go index 6a9129c..686d7e5 100644 --- a/server/model/common.go +++ b/server/model/common.go @@ -70,6 +70,10 @@ type ChannelInfo struct { Position string `json:"position" mapstructure:"position"` } +type ExtraInfo struct { + EnableUserRole bool `json:"enable_user_role" mapstructure:"enable_user_role"` +} + type BufferConfig struct { Period int `json:"period" mapstructure:"period"` Size int `json:"size" mapstructure:"size"` diff --git a/server/model/meta/task.go b/server/model/meta/task.go index ea5d84a..21e76c3 100644 --- a/server/model/meta/task.go +++ b/server/model/meta/task.go @@ -66,6 +66,7 @@ type TaskInfo struct { CollectionInfos []model.CollectionInfo DatabaseInfo model.DatabaseInfo RPCRequestChannelInfo model.ChannelInfo + ExtraInfo model.ExtraInfo ExcludeCollections []string // it's used for the `*` collection name State TaskState Reason string diff --git a/server/model/request/create.go b/server/model/request/create.go index 0683e75..467439c 100644 --- a/server/model/request/create.go +++ b/server/model/request/create.go @@ -27,6 +27,7 @@ type CreateRequest struct { CollectionInfos []model.CollectionInfo `json:"collection_infos" mapstructure:"collection_infos"` DatabaseInfo model.DatabaseInfo `json:"database_info" mapstructure:"database_info"` RPCChannelInfo model.ChannelInfo `json:"rpc_channel_info" mapstructure:"rpc_channel_info"` + ExtraInfo model.ExtraInfo `json:"extra_info" mapstructure:"extra_info"` BufferConfig model.BufferConfig `json:"buffer_config" mapstructure:"buffer_config"` // Deprecated Positions map[string]string `json:"positions" mapstructure:"positions"`