diff --git a/helper.go b/helper.go index ac79668..08da786 100644 --- a/helper.go +++ b/helper.go @@ -111,6 +111,27 @@ func stringToObjectID(object map[string]interface{}) error { return nil } +// sliceToObjectID converts _id key from slice of strings to slice of bson.ObjectId +func sliceToObjectID(object map[string]interface{}) error { + if id, ok := object["id"]; ok { + delete(object, "id") + ids := strings.Split(id.(string), ",") + bsonIds := []bson.ObjectId{} + for _, id := range ids { + if !bson.IsObjectIdHex(id) { + return ErrInvalidInput("id is a invalid hex representation of an ObjectId") + } + + if reflect.TypeOf(id).String() != "bson.ObjectId" { + bsonIds = append(bsonIds, bson.ObjectIdHex(id)) + } + } + object["_id"] = bsonIds + } + + return nil +} + // IsConditionalCheckErr check if err is dynamoDB condition error func IsConditionalCheckErr(err error) bool { if ae, ok := err.(awserr.RequestFailure); ok { diff --git a/mongodb.go b/mongodb.go index d937367..b9e4194 100644 --- a/mongodb.go +++ b/mongodb.go @@ -218,8 +218,17 @@ func (s *MongoSession) GetAll(filter Filter, resultsTypeHint interface{}, order slicePointer.Elem().Set(results) if !s.repoDef.IsCustomID() { - if err := stringToObjectID(filter); err != nil { - return nil, ErrInvalidInput(err) + if id, ok := filter["id"]; ok { + // check if id field contains values separated by comma + if ok := strings.Contains(id.(string), ","); ok { + if err := sliceToObjectID(filter); err != nil { + return nil, ErrInvalidInput(err) + } + } else { + if err := stringToObjectID(filter); err != nil { + return nil, ErrInvalidInput(err) + } + } } } @@ -417,6 +426,24 @@ func toMongoFilter(filter Filter) (map[string]interface{}, error) { } return nil, fmt.Errorf("unknown filter specification - supported type is $pattern") } + // if filter key contains multiple values to search by + if val, ok := value.(string); ok { + if values := strings.Split(val, ","); len(values) > 1 { + mgf[key] = bson.M{ + "$in": values, + } + continue + } + } + // if filter _id contains multiple id to search by + if val, ok := value.([]bson.ObjectId); ok { + if len(val) > 1 { + mgf[key] = bson.M{ + "$in": val, + } + continue + } + } mgf[key] = value // copy over the key=>value pairs to do exact matching } return mgf, nil