diff --git a/e2e/mongo/nested/gen_mongo_config.go b/e2e/mongo/nested/gen_mongo_config.go index a8f3daaf..c685b954 100644 --- a/e2e/mongo/nested/gen_mongo_config.go +++ b/e2e/mongo/nested/gen_mongo_config.go @@ -3,6 +3,7 @@ package nested import ( "context" "fmt" + "sync" "github.com/ezbuy/ezorm/v2/pkg/db" "github.com/ezbuy/wrapper/database" @@ -36,6 +37,7 @@ func WithPostHooks(fn ...func()) SetupOptionFn { } var mongoDriver *db.MongoDriver +var mongoDriverOnce sync.Once func MgoSetup(config *db.MongoConfig, opts ...SetupOptionFn) { sopt := &SetupOption{} @@ -46,23 +48,30 @@ func MgoSetup(config *db.MongoConfig, opts ...SetupOptionFn) { sopt.postHooks = append(sopt.postHooks, UserIndexesFunc, ) - var dopt []db.MongoDriverOption + var dopt []db.MongoDriverConnOptionFn if sopt.monitor != nil { - dopt = append(dopt, db.WithPoolMonitor(database.NewMongoDriverMonitor(sopt.monitor))) + clientOpt := db.WithClientOption(db.WithPoolMonitor(database.NewMongoDriverMonitor(sopt.monitor))) + dopt = append(dopt, clientOpt) } - db.Setup(config) - - var err error - mongoDriver, err = db.NewMongoDriver( - context.Background(), - dopt..., - ) - if err != nil { - panic(fmt.Errorf("failed to create mongodb driver: %s", err)) - } - for _, hook := range sopt.postHooks { - hook() + if config.DBName == "" { + panic("db name is required") } + db.SetupMany(config) + dopt = append(dopt, db.WithDBName(config.DBName)) + + mongoDriverOnce.Do(func() { + var err error + mongoDriver, err = db.NewMongoDriverBy( + context.Background(), + dopt..., + ) + if err != nil { + panic(fmt.Errorf("failed to create mongodb driver: %s", err)) + } + for _, hook := range sopt.postHooks { + hook() + } + }) } func Col(col string) *mongo.Collection { diff --git a/e2e/mongo/user/gen_mongo_config.go b/e2e/mongo/user/gen_mongo_config.go index 7fb5a5c1..884af8eb 100644 --- a/e2e/mongo/user/gen_mongo_config.go +++ b/e2e/mongo/user/gen_mongo_config.go @@ -3,6 +3,7 @@ package user import ( "context" "fmt" + "sync" "github.com/ezbuy/ezorm/v2/pkg/db" "github.com/ezbuy/wrapper/database" @@ -36,6 +37,7 @@ func WithPostHooks(fn ...func()) SetupOptionFn { } var mongoDriver *db.MongoDriver +var mongoDriverOnce sync.Once func MgoSetup(config *db.MongoConfig, opts ...SetupOptionFn) { sopt := &SetupOption{} @@ -47,23 +49,30 @@ func MgoSetup(config *db.MongoConfig, opts ...SetupOptionFn) { UserIndexesFunc, UserBlogIndexesFunc, ) - var dopt []db.MongoDriverOption + var dopt []db.MongoDriverConnOptionFn if sopt.monitor != nil { - dopt = append(dopt, db.WithPoolMonitor(database.NewMongoDriverMonitor(sopt.monitor))) + clientOpt := db.WithClientOption(db.WithPoolMonitor(database.NewMongoDriverMonitor(sopt.monitor))) + dopt = append(dopt, clientOpt) } - db.Setup(config) - - var err error - mongoDriver, err = db.NewMongoDriver( - context.Background(), - dopt..., - ) - if err != nil { - panic(fmt.Errorf("failed to create mongodb driver: %s", err)) - } - for _, hook := range sopt.postHooks { - hook() + if config.DBName == "" { + panic("db name is required") } + db.SetupMany(config) + dopt = append(dopt, db.WithDBName(config.DBName)) + + mongoDriverOnce.Do(func() { + var err error + mongoDriver, err = db.NewMongoDriverBy( + context.Background(), + dopt..., + ) + if err != nil { + panic(fmt.Errorf("failed to create mongodb driver: %s", err)) + } + for _, hook := range sopt.postHooks { + hook() + } + }) } func Col(col string) *mongo.Collection { diff --git a/e2e/mongo/user/user_test.go b/e2e/mongo/user/user_test.go index b94d9262..868cc822 100644 --- a/e2e/mongo/user/user_test.go +++ b/e2e/mongo/user/user_test.go @@ -7,19 +7,21 @@ import ( "testing" "time" + "github.com/ezbuy/ezorm/v2/e2e/mongo/nested" "github.com/ezbuy/ezorm/v2/e2e/mongo/user" "github.com/ezbuy/ezorm/v2/pkg/db" "github.com/ezbuy/ezorm/v2/pkg/orm" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) -func getConfigFromEnv() *db.MongoConfig { +func getConfigFromEnv(dbName string) *db.MongoConfig { return &db.MongoConfig{ - DBName: "ezorm", + DBName: dbName, MongoDB: fmt.Sprintf( "mongodb://%s:%s@%s:%s", os.Getenv("MONGO_USER"), @@ -31,7 +33,7 @@ func getConfigFromEnv() *db.MongoConfig { } func TestMain(m *testing.M) { - user.MgoSetup(getConfigFromEnv()) + user.MgoSetup(getConfigFromEnv("ezorm")) expr := 3600 exprInt32 := int32(expr) var exist, created int @@ -79,6 +81,43 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +func TestOperateMultipleDB(t *testing.T) { + // fetch user from ezorm + user.MgoSetup(getConfigFromEnv("ezorm")) + // fetch user from ezorm_nested + nested.MgoSetup(getConfigFromEnv("ezorm_nested")) + + u1 := user.Get_UserMgr().NewUser() + u1.Username = "username_1" + if _, err := u1.Save(context.TODO()); err != nil { + t.Fatalf("failed to save user: %s", err) + } + + c := nested.Get_UserMgr().Count(context.TODO(), bson.M{}) + require.Equalf(t, 0, c, "unexpected count of users, got: %d, expect: %d", c, 0) + + u2 := nested.Get_UserMgr().NewUser() + u2.Username = "username_2" + if _, err := u2.Save(context.TODO()); err != nil { + t.Fatalf("failed to save user: %s", err) + } + + c2 := nested.Get_UserMgr().Count(context.TODO(), bson.M{}) + require.Equalf(t, 1, c2, "unexpected count of users, got: %d, expect: %d", c2, 1) + + c3 := user.Get_UserMgr().Count(context.TODO(), bson.M{}) + require.Equalf(t, 1, c3, "unexpected count of users, got: %d, expect: %d", c3, 1) + + t.Cleanup(func() { + if _, err := user.Get_UserMgr().RemoveAll(context.TODO(), nil); err != nil { + t.Fatalf("failed to remove all users: %s", err) + } + if _, err := nested.Get_UserMgr().RemoveAll(context.TODO(), nil); err != nil { + t.Fatalf("failed to remove all users: %s", err) + } + }) +} + func TestSave(t *testing.T) { ctx := context.TODO() u1 := user.Get_UserMgr().NewUser() diff --git a/go.mod b/go.mod index 81d8dd18..4e1a1ad9 100644 --- a/go.mod +++ b/go.mod @@ -94,3 +94,5 @@ require ( gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0 // indirect ) + +replace github.com/ezbuy/ezorm/v2/pkg => ./v2/pkg diff --git a/go.sum b/go.sum index 5f2a5a9c..57f7576a 100644 --- a/go.sum +++ b/go.sum @@ -173,8 +173,6 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= -github.com/ezbuy/ezorm/v2/pkg v0.0.10 h1:F5NCRrW+1pn67OBZgSp5XzUJPIcDS47xLzkgt8quhrw= -github.com/ezbuy/ezorm/v2/pkg v0.0.10/go.mod h1:OJv0Xk6t+tGCClK0OaRMubrdjfDVb/uEgqVLs6btC7E= github.com/ezbuy/statsd v0.0.0-20190521031639-ae237584062d h1:SeHyPo+ykKm5DKcBUDEqTW5FNDS2kIBx6iFie3OttPw= github.com/ezbuy/statsd v0.0.0-20190521031639-ae237584062d/go.mod h1:RUr3GtpMTjto7ygeUCLHDs6uBzvV5aHSzGUbr8YZkJU= github.com/ezbuy/utils v0.0.0-20170609090716-8ac4beef008f h1:f6mjy3cXO0yeEROA2hs94Gj3iOpQExCZt0bP9O4dW98= diff --git a/internal/parser/shared/tpl/mongo_config.gogo b/internal/parser/shared/tpl/mongo_config.gogo index b80c5814..a16b406b 100644 --- a/internal/parser/shared/tpl/mongo_config.gogo +++ b/internal/parser/shared/tpl/mongo_config.gogo @@ -6,6 +6,7 @@ package {{ $first.GoPackage }} import ( "context" "fmt" + "sync" "github.com/ezbuy/ezorm/v2/pkg/db" "github.com/ezbuy/wrapper/database" @@ -40,6 +41,7 @@ func WithPostHooks(fn ...func()) SetupOptionFn { } var mongoDriver *db.MongoDriver +var mongoDriverOnce sync.Once func MgoSetup(config *db.MongoConfig, opts ...SetupOptionFn) { sopt := &SetupOption{} @@ -54,23 +56,30 @@ func MgoSetup(config *db.MongoConfig, opts ...SetupOptionFn) { {{- end}} {{- end}} ) - var dopt []db.MongoDriverOption + var dopt []db.MongoDriverConnOptionFn if sopt.monitor != nil { - dopt = append(dopt, db.WithPoolMonitor(database.NewMongoDriverMonitor(sopt.monitor))) + clientOpt := db.WithClientOption(db.WithPoolMonitor(database.NewMongoDriverMonitor(sopt.monitor))) + dopt = append(dopt, clientOpt) } - db.Setup(config) - - var err error - mongoDriver, err = db.NewMongoDriver( - context.Background(), - dopt..., - ) - if err != nil { - panic(fmt.Errorf("failed to create mongodb driver: %s", err)) - } - for _, hook := range sopt.postHooks { - hook() + if config.DBName == "" { + panic("db name is required") } + db.SetupMany(config) + dopt = append(dopt, db.WithDBName(config.DBName)) + + mongoDriverOnce.Do(func() { + var err error + mongoDriver, err = db.NewMongoDriverBy( + context.Background(), + dopt..., + ) + if err != nil { + panic(fmt.Errorf("failed to create mongodb driver: %s", err)) + } + for _, hook := range sopt.postHooks { + hook() + } + }) } func Col(col string) *mongo.Collection { diff --git a/v2/pkg/db/mongo_config.go b/v2/pkg/db/mongo_config.go index d8587c2b..b3d33977 100644 --- a/v2/pkg/db/mongo_config.go +++ b/v2/pkg/db/mongo_config.go @@ -1,5 +1,7 @@ package db +import "sync" + var config *MongoConfig type MongoConfig struct { @@ -9,6 +11,39 @@ type MongoConfig struct { MaxSession int } +// Setup setup one mongo config +// For multiple configs, use `SetupMany` instead func Setup(c *MongoConfig) { config = c } + +var multiConfigs sync.Map + +// SetupMany setup many mongo configs +// For singleton , use `Setup` instead +func SetupMany( + cs ...*MongoConfig, +) { + if config == nil && len(cs) == 1 { + Setup(cs[0]) + multiConfigs.Store(cs[0].DBName, cs[0]) + return + } + for _, c := range cs { + multiConfigs.Store(c.DBName, c) + } +} + +func GetConfigByName(name string) *MongoConfig { + if v, ok := multiConfigs.Load(name); ok { + return v.(*MongoConfig) + } + return nil +} + +func GetConfig(name string) *MongoConfig { + if config != nil { + return config + } + return GetConfigByName(name) +} diff --git a/v2/pkg/db/mongodriver.go b/v2/pkg/db/mongodriver.go index 160101c4..29822a6f 100644 --- a/v2/pkg/db/mongodriver.go +++ b/v2/pkg/db/mongodriver.go @@ -11,6 +11,66 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) +type MongoDriverConnOption struct { + clientOption *MongoDriverOption + dbName string +} + +type MongoDriverConnOptionFn func(*MongoDriverConnOption) + +func WithDBName(dbName string) MongoDriverConnOptionFn { + return func(opt *MongoDriverConnOption) { + opt.dbName = dbName + } +} + +func WithClientOption(opt MongoDriverOption) MongoDriverConnOptionFn { + return func(connOpt *MongoDriverConnOption) { + connOpt.clientOption = &opt + } +} + +func NewMongoDriverBy( + ctx context.Context, + opts ...MongoDriverConnOptionFn, +) (*MongoDriver, error) { + if config == nil { + return nil, errors.New("db: initialize config before new mongo driver") + } + connOpt := &MongoDriverConnOption{} + for _, opt := range opts { + opt(connOpt) + } + if connOpt.dbName == "" { + return NewMongoDriver(ctx, *connOpt.clientOption) + } + config := GetConfigByName(connOpt.dbName) + uri := config.MongoDB + if !strings.HasPrefix(uri, "mongodb") { + uri = "mongodb://" + config.MongoDB + } + + cliOpts := options.Client().ApplyURI(uri).SetMaxPoolSize(uint64(config.PoolLimit)) + + cli, err := mongo.NewClient(cliOpts) + if err != nil { + return nil, fmt.Errorf("failed to create mongodb client: %w", err) + } + + if err = cli.Connect(ctx); err != nil { + return nil, fmt.Errorf("failed to connect to mongodb server: %w", err) + } + + if err = cli.Ping(ctx, nil); err != nil { + return nil, fmt.Errorf("failed to ping remote mongodb server: %w", err) + } + + return &MongoDriver{ + cli: cli, + dbName: connOpt.dbName, + }, nil +} + func NewMongoDriver(ctx context.Context, opts ...MongoDriverOption) (*MongoDriver, error) { if config == nil { return nil, errors.New("db: initialize config before new mongo driver") @@ -51,10 +111,14 @@ func WithPoolMonitor(m *event.PoolMonitor) MongoDriverOption { } type MongoDriver struct { - cli *mongo.Client + cli *mongo.Client + dbName string } func (md *MongoDriver) GetCol(cname string) *mongo.Collection { + if md.dbName != "" { + return md.cli.Database(md.dbName).Collection(cname) + } return md.cli.Database(config.DBName).Collection(cname) } diff --git a/v2/pkg/db/mongodriver_test.go b/v2/pkg/db/mongodriver_test.go index ba77cd25..441b6986 100644 --- a/v2/pkg/db/mongodriver_test.go +++ b/v2/pkg/db/mongodriver_test.go @@ -45,6 +45,63 @@ func Test_MongoDriverConnection(t *testing.T) { t.Logf("got insert id: %s", ret.InsertedID.(primitive.ObjectID).String()) } +func Test_MongoDriverConnectionBy(t *testing.T) { + ctx, cancelFn := context.WithTimeout(context.Background(), 3*time.Second) + defer cancelFn() + + db.SetupMany(&db.MongoConfig{ + MongoDB: "mongodb://127.0.0.1:27017", + DBName: "test", + PoolLimit: 30, + }, &db.MongoConfig{ + MongoDB: "mongodb://127.0.0.1:27017", + DBName: "test_2", + PoolLimit: 30, + }) + + md, err := db.NewMongoDriverBy(ctx, db.WithDBName("test")) + if err != nil { + t.Fatalf("failed to new mongo driver: %s", err) + } + defer md.Close() + + const ( + collectionName = "Test" + ) + + col := md.GetCol(collectionName) + ret, err := col.InsertOne(ctx, bson.M{ + "tid": 1, + "ezorm": "mongo_driver_support", + "create_date": time.Now().Unix(), + }) + if err != nil { + t.Fatalf("failed to insert to collection: %s", err) + } + t.Logf("got insert id: %s", ret.InsertedID.(primitive.ObjectID).String()) + + md2, err := db.NewMongoDriverBy(ctx, db.WithDBName("test_2")) + if err != nil { + t.Fatalf("failed to new mongo driver: %s", err) + } + defer md2.Close() + + const ( + collectionName2 = "Test" + ) + + col2 := md.GetCol(collectionName2) + ret2, err := col2.InsertOne(ctx, bson.M{ + "tid": 1, + "ezorm": "mongo_driver_support", + "create_date": time.Now().Unix(), + }) + if err != nil { + t.Fatalf("failed to insert to collection: %s", err) + } + t.Logf("got insert id: %s", ret2.InsertedID.(primitive.ObjectID).String()) +} + func Test_MongoDriverConnPool(t *testing.T) { connIds := make(map[uint64]int) connIdsLock := new(sync.Mutex)