From 974dd52edeb27f2f220a5301f6da130a8146aa22 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 11:00:35 +0800 Subject: [PATCH 01/11] chore: remove wire --- contracts/database/config.go | 32 +++ contracts/database/gorm/wire_interface.go | 19 -- contracts/database/orm/constants.go | 1 + contracts/database/orm/orm.go | 2 + contracts/foundation/application.go | 2 + database/console/migrate.go | 18 +- database/db/config.go | 76 ------- database/db/config_test.go | 122 ----------- database/db/configs.go | 88 ++++++++ database/db/configs_test.go | 232 ++++++++++++++++++++ database/db/dsn.go | 89 +++----- database/db/dsn_test.go | 110 +++++----- database/gorm/cursor.go | 2 +- database/gorm/dialector.go | 109 ++------- database/gorm/dialector_test.go | 169 +++++++------- database/gorm/event.go | 6 +- database/gorm/event_test.go | 10 +- database/gorm/gorm.go | 71 +++--- database/gorm/query.go | 256 +++++++++++----------- database/gorm/query_test.go | 2 +- database/gorm/test_utils.go | 12 +- database/gorm/to_sql.go | 4 +- database/gorm/to_sql_test.go | 68 +++--- database/gorm/wire.go | 29 --- database/gorm/wire_gen.go | 35 --- database/orm.go | 63 +++--- database/orm_test.go | 4 +- database/service_provider.go | 19 +- database/wire.go | 23 -- database/wire_gen.go | 32 --- database/wire_interface.go | 24 -- foundation/container.go | 4 + go.mod | 1 - go.sum | 18 -- mocks/database/Configs.go | 129 +++++++++++ mocks/database/gorm/Gorm.go | 92 -------- mocks/database/gorm/Initialize.go | 151 ------------- mocks/database/orm/Orm.go | 32 +++ mocks/foundation/Application.go | 33 +++ testing/docker/database.go | 35 +-- 40 files changed, 1021 insertions(+), 1203 deletions(-) delete mode 100644 contracts/database/gorm/wire_interface.go delete mode 100644 database/db/config.go delete mode 100644 database/db/config_test.go create mode 100644 database/db/configs.go create mode 100644 database/db/configs_test.go delete mode 100644 database/gorm/wire.go delete mode 100644 database/gorm/wire_gen.go delete mode 100644 database/wire.go delete mode 100644 database/wire_gen.go delete mode 100644 database/wire_interface.go create mode 100644 mocks/database/Configs.go delete mode 100644 mocks/database/gorm/Gorm.go delete mode 100644 mocks/database/gorm/Initialize.go diff --git a/contracts/database/config.go b/contracts/database/config.go index 3e03fbe57..641d3c8b5 100644 --- a/contracts/database/config.go +++ b/contracts/database/config.go @@ -1,5 +1,19 @@ package database +const ( + DriverMysql Driver = "mysql" + DriverPostgres Driver = "postgres" + DriverSqlite Driver = "sqlite" + DriverSqlserver Driver = "sqlserver" +) + +type Driver string + +func (d Driver) String() string { + return string(d) +} + +// Config Used in config/database.go type Config struct { Host string Port int @@ -7,3 +21,21 @@ type Config struct { Username string Password string } + +// FullConfig Fill the default value for Config +type FullConfig struct { + Config + Driver Driver + Connection string + Prefix string + Singular bool + Charset string // Mysql, Sqlserver + Loc string // Mysql + Sslmode string // Postgres + Timezone string // Postgres +} + +type Configs interface { + Reads() []FullConfig + Writes() []FullConfig +} diff --git a/contracts/database/gorm/wire_interface.go b/contracts/database/gorm/wire_interface.go deleted file mode 100644 index 841eeb3ce..000000000 --- a/contracts/database/gorm/wire_interface.go +++ /dev/null @@ -1,19 +0,0 @@ -package gorm - -import ( - "context" - - gormio "gorm.io/gorm" - - "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/contracts/database/orm" -) - -type Gorm interface { - Make() (*gormio.DB, error) -} - -type Initialize interface { - InitializeGorm(config config.Config, connection string) Gorm - InitializeQuery(ctx context.Context, config config.Config, connection string) (orm.Query, error) -} diff --git a/contracts/database/orm/constants.go b/contracts/database/orm/constants.go index c66bb6600..1217b0c32 100644 --- a/contracts/database/orm/constants.go +++ b/contracts/database/orm/constants.go @@ -1,5 +1,6 @@ package orm +// DEPRECATED Move to contracts/database/config.go const ( DriverMysql Driver = "mysql" DriverPostgres Driver = "postgres" diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index 108d8d27e..e15846ba7 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -16,6 +16,8 @@ type Orm interface { Factory() Factory // Observe registers an observer with the Orm. Observe(model any, observer Observer) + // Refresh resets the Orm instance. + Refresh() // Transaction runs a callback wrapped in a database transaction. Transaction(txFunc func(tx Query) error) error // WithContext sets the context to be used by the Orm. diff --git a/contracts/foundation/application.go b/contracts/foundation/application.go index 5694815c6..c194f9898 100644 --- a/contracts/foundation/application.go +++ b/contracts/foundation/application.go @@ -123,6 +123,8 @@ type Application interface { MakeSeeder() seeder.Facade // MakeWith resolves the given type with the given parameters from the container. MakeWith(key any, parameters map[string]any) (any, error) + // Refresh an instance on the given target. + Refresh(key any) // Singleton registers a shared binding in the container. Singleton(key any, callback func(app Application) (any, error)) } diff --git a/database/console/migrate.go b/database/console/migrate.go index 077006783..2b449e396 100644 --- a/database/console/migrate.go +++ b/database/console/migrate.go @@ -13,7 +13,7 @@ import ( "github.com/goravel/framework/contracts/config" "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/console/driver" - "github.com/goravel/framework/database/db" + databasedb "github.com/goravel/framework/database/db" "github.com/goravel/framework/support" ) @@ -25,16 +25,15 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { dir = fmt.Sprintf("file://%s/database/migrations", support.RelativePath) } - gormConfig := db.NewConfigImpl(config, connection) - writeConfigs := gormConfig.Writes() + configs := databasedb.NewConfigs(config, connection) + writeConfigs := configs.Writes() if len(writeConfigs) == 0 { return nil, errors.New("not found database configuration") } switch orm.Driver(driver) { case orm.DriverMysql: - dsn := db.NewDsnImpl(config, connection) - mysqlDsn := dsn.Mysql(writeConfigs[0]) + mysqlDsn := databasedb.Dsn(writeConfigs[0]) if mysqlDsn == "" { return nil, nil } @@ -53,8 +52,7 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { return migrate.NewWithDatabaseInstance(dir, "mysql", instance) case orm.DriverPostgres: - dsn := db.NewDsnImpl(config, connection) - postgresDsn := dsn.Postgres(writeConfigs[0]) + postgresDsn := databasedb.Dsn(writeConfigs[0]) if postgresDsn == "" { return nil, nil } @@ -73,8 +71,7 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { return migrate.NewWithDatabaseInstance(dir, "postgres", instance) case orm.DriverSqlite: - dsn := db.NewDsnImpl(config, "") - sqliteDsn := dsn.Sqlite(writeConfigs[0]) + sqliteDsn := databasedb.Dsn(writeConfigs[0]) if sqliteDsn == "" { return nil, nil } @@ -93,8 +90,7 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { return migrate.NewWithDatabaseInstance(dir, "sqlite3", instance) case orm.DriverSqlserver: - dsn := db.NewDsnImpl(config, connection) - sqlserverDsn := dsn.Sqlserver(writeConfigs[0]) + sqlserverDsn := databasedb.Dsn(writeConfigs[0]) if sqlserverDsn == "" { return nil, nil } diff --git a/database/db/config.go b/database/db/config.go deleted file mode 100644 index 9ef415022..000000000 --- a/database/db/config.go +++ /dev/null @@ -1,76 +0,0 @@ -package db - -import ( - "fmt" - - "github.com/google/wire" - - "github.com/goravel/framework/contracts/config" - databasecontract "github.com/goravel/framework/contracts/database" - "github.com/goravel/framework/contracts/database/orm" -) - -var ConfigSet = wire.NewSet(NewConfigImpl, wire.Bind(new(Config), new(*ConfigImpl))) -var _ Config = &ConfigImpl{} - -type Config interface { - Reads() []databasecontract.Config - Writes() []databasecontract.Config -} - -type ConfigImpl struct { - config config.Config - connection string -} - -func NewConfigImpl(config config.Config, connection string) *ConfigImpl { - return &ConfigImpl{ - config: config, - connection: connection, - } -} - -func (c *ConfigImpl) Reads() []databasecontract.Config { - configs := c.config.Get(fmt.Sprintf("database.connections.%s.read", c.connection)) - if configs, ok := configs.([]databasecontract.Config); ok { - return c.fillDefault(configs) - } - - return []databasecontract.Config{} -} - -func (c *ConfigImpl) Writes() []databasecontract.Config { - configs := c.config.Get(fmt.Sprintf("database.connections.%s.write", c.connection)) - if configs, ok := configs.([]databasecontract.Config); ok { - return c.fillDefault(configs) - } - - return c.fillDefault([]databasecontract.Config{{}}) -} - -func (c *ConfigImpl) fillDefault(configs []databasecontract.Config) []databasecontract.Config { - var newConfigs []databasecontract.Config - driver := c.config.GetString(fmt.Sprintf("database.connections.%s.driver", c.connection)) - for _, item := range configs { - if driver != orm.DriverSqlite.String() { - if item.Host == "" { - item.Host = c.config.GetString(fmt.Sprintf("database.connections.%s.host", c.connection)) - } - if item.Port == 0 { - item.Port = c.config.GetInt(fmt.Sprintf("database.connections.%s.port", c.connection)) - } - if item.Username == "" { - item.Username = c.config.GetString(fmt.Sprintf("database.connections.%s.username", c.connection)) - } - if item.Password == "" { - item.Password = c.config.GetString(fmt.Sprintf("database.connections.%s.password", c.connection)) - } - } - if item.Database == "" { - item.Database = c.config.GetString(fmt.Sprintf("database.connections.%s.database", c.connection)) - } - newConfigs = append(newConfigs, item) - } - - return newConfigs -} diff --git a/database/db/config_test.go b/database/db/config_test.go deleted file mode 100644 index fd3d986fd..000000000 --- a/database/db/config_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/suite" - - databasecontract "github.com/goravel/framework/contracts/database" - configmock "github.com/goravel/framework/mocks/config" -) - -type ConfigTestSuite struct { - suite.Suite - config *ConfigImpl - connection string - mockConfig *configmock.Config -} - -func TestConfigTestSuite(t *testing.T) { - suite.Run(t, &ConfigTestSuite{ - connection: "mysql", - }) -} - -func (s *ConfigTestSuite) SetupTest() { - s.mockConfig = &configmock.Config{} - s.config = NewConfigImpl(s.mockConfig, s.connection) -} - -func (s *ConfigTestSuite) TestFillDefaultForConfigs() { - host := "localhost" - port := 3306 - database := "forge" - username := "root" - password := "123123" - - tests := []struct { - name string - configs []databasecontract.Config - setup func() - expectConfigs []databasecontract.Config - }{ - { - name: "success when configs is empty", - configs: []databasecontract.Config{}, - setup: func() { - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("mysql").Once() - }, - }, - { - name: "success when configs have item but key is empty", - configs: []databasecontract.Config{{}}, - setup: func() { - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("mysql").Once() - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", s.connection)).Return(host).Once() - s.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", s.connection)).Return(port).Once() - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", s.connection)).Return(database).Once() - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", s.connection)).Return(username).Once() - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", s.connection)).Return(password).Once() - }, - expectConfigs: []databasecontract.Config{ - { - Host: host, - Port: port, - Database: database, - Username: username, - Password: password, - }, - }, - }, - { - name: "success when configs have item", - configs: []databasecontract.Config{ - { - Host: "localhost", - Port: 3306, - Database: "forge", - Username: "root", - Password: "123123", - }, - }, - setup: func() { - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("mysql").Once() - }, - expectConfigs: []databasecontract.Config{ - { - Host: "localhost", - Port: 3306, - Database: "forge", - Username: "root", - Password: "123123", - }, - }, - }, - { - name: "success when sqlite", - configs: []databasecontract.Config{ - { - Database: "forge", - }, - }, - setup: func() { - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("sqlite").Once() - }, - expectConfigs: []databasecontract.Config{ - { - Database: "forge", - }, - }, - }, - } - - for _, test := range tests { - s.Run(test.name, func() { - test.setup() - configs := s.config.fillDefault(test.configs) - s.Equal(test.expectConfigs, configs) - s.mockConfig.AssertExpectations(s.T()) - }) - } -} diff --git a/database/db/configs.go b/database/db/configs.go new file mode 100644 index 000000000..1e8315206 --- /dev/null +++ b/database/db/configs.go @@ -0,0 +1,88 @@ +package db + +import ( + "fmt" + + contractsconfig "github.com/goravel/framework/contracts/config" + "github.com/goravel/framework/contracts/database" +) + +type Configs struct { + config contractsconfig.Config + connection string +} + +func NewConfigs(config contractsconfig.Config, connection string) *Configs { + return &Configs{ + config: config, + connection: connection, + } +} + +func (c *Configs) Reads() []database.FullConfig { + configs := c.config.Get(fmt.Sprintf("database.connections.%s.read", c.connection)) + if configs, ok := configs.([]database.Config); ok { + return c.fillDefault(configs) + } + + return nil +} + +func (c *Configs) Writes() []database.FullConfig { + configs := c.config.Get(fmt.Sprintf("database.connections.%s.write", c.connection)) + if configs, ok := configs.([]database.Config); ok { + return c.fillDefault(configs) + } + + // Use default db configuration when write is empty + return c.fillDefault([]database.Config{{}}) +} + +func (c *Configs) fillDefault(configs []database.Config) []database.FullConfig { + if len(configs) == 0 { + return nil + } + + var fullConfigs []database.FullConfig + driver := database.Driver(c.config.GetString(fmt.Sprintf("database.connections.%s.driver", c.connection))) + + for _, config := range configs { + fullConfig := database.FullConfig{ + Config: config, + Connection: c.connection, + Driver: driver, + Prefix: c.config.GetString(fmt.Sprintf("database.connections.%s.prefix", c.connection)), + Singular: c.config.GetBool(fmt.Sprintf("database.connections.%s.singular", c.connection)), + } + if driver != database.DriverSqlite { + if fullConfig.Host == "" { + fullConfig.Host = c.config.GetString(fmt.Sprintf("database.connections.%s.host", c.connection)) + } + if fullConfig.Port == 0 { + fullConfig.Port = c.config.GetInt(fmt.Sprintf("database.connections.%s.port", c.connection)) + } + if fullConfig.Username == "" { + fullConfig.Username = c.config.GetString(fmt.Sprintf("database.connections.%s.username", c.connection)) + } + if fullConfig.Password == "" { + fullConfig.Password = c.config.GetString(fmt.Sprintf("database.connections.%s.password", c.connection)) + } + if driver == database.DriverMysql { + fullConfig.Charset = c.config.GetString(fmt.Sprintf("database.connections.%s.charset", c.connection)) + } + if driver == database.DriverMysql || driver == database.DriverSqlserver { + fullConfig.Loc = c.config.GetString(fmt.Sprintf("database.connections.%s.loc", c.connection)) + } + if driver == database.DriverPostgres { + fullConfig.Sslmode = c.config.GetString(fmt.Sprintf("database.connections.%s.sslmode", c.connection)) + fullConfig.Timezone = c.config.GetString(fmt.Sprintf("database.connections.%s.timezone", c.connection)) + } + } + if config.Database == "" { + fullConfig.Database = c.config.GetString(fmt.Sprintf("database.connections.%s.database", c.connection)) + } + fullConfigs = append(fullConfigs, fullConfig) + } + + return fullConfigs +} diff --git a/database/db/configs_test.go b/database/db/configs_test.go new file mode 100644 index 000000000..08db5c588 --- /dev/null +++ b/database/db/configs_test.go @@ -0,0 +1,232 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + + contractsdatabase "github.com/goravel/framework/contracts/database" + mocksconfig "github.com/goravel/framework/mocks/config" +) + +type ConfigTestSuite struct { + suite.Suite + configs *Configs + connection string + mockConfig *mocksconfig.Config +} + +func TestConfigTestSuite(t *testing.T) { + suite.Run(t, &ConfigTestSuite{ + connection: "mysql", + }) +} + +func (s *ConfigTestSuite) SetupTest() { + s.mockConfig = mocksconfig.NewConfig(s.T()) + s.configs = NewConfigs(s.mockConfig, s.connection) +} + +func (s *ConfigTestSuite) TestReads() { + database := "forge" + prefix := "goravel_" + singular := false + + // Test when configs is empty + s.mockConfig.EXPECT().Get("database.connections.mysql.read").Return(nil).Once() + s.Nil(s.configs.Reads()) + + // Test when configs is not empty + s.mockConfig.EXPECT().Get("database.connections.mysql.read").Return([]contractsdatabase.Config{ + { + Database: database, + }, + }).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", s.connection)).Return(prefix).Once() + s.mockConfig.EXPECT().GetBool(fmt.Sprintf("database.connections.%s.singular", s.connection)).Return(singular).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return(contractsdatabase.DriverSqlite.String()).Once() + + s.Equal([]contractsdatabase.FullConfig{ + { + Connection: s.connection, + Driver: contractsdatabase.DriverSqlite, + Prefix: prefix, + Config: contractsdatabase.Config{ + Database: database, + }, + }, + }, s.configs.Reads()) +} + +func (s *ConfigTestSuite) TestWrites() { + database := "forge" + prefix := "goravel_" + singular := false + + // Test when configs is empty + s.mockConfig.EXPECT().Get("database.connections.mysql.write").Return(nil).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return(contractsdatabase.DriverSqlite.String()).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.database", s.connection)).Return(database).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", s.connection)).Return(prefix).Once() + s.mockConfig.EXPECT().GetBool(fmt.Sprintf("database.connections.%s.singular", s.connection)).Return(singular).Once() + + s.Equal([]contractsdatabase.FullConfig{ + { + Connection: s.connection, + Driver: contractsdatabase.DriverSqlite, + Prefix: prefix, + Config: contractsdatabase.Config{ + Database: database, + }, + }, + }, s.configs.Writes()) + + // Test when configs is not empty + s.mockConfig.EXPECT().Get("database.connections.mysql.write").Return([]contractsdatabase.Config{ + { + Database: database, + }, + }).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return(contractsdatabase.DriverSqlite.String()).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", s.connection)).Return(prefix).Once() + s.mockConfig.EXPECT().GetBool(fmt.Sprintf("database.connections.%s.singular", s.connection)).Return(singular).Once() + + s.Equal([]contractsdatabase.FullConfig{ + { + Connection: s.connection, + Driver: contractsdatabase.DriverSqlite, + Prefix: prefix, + Config: contractsdatabase.Config{ + Database: database, + }, + }, + }, s.configs.Writes()) +} + +func (s *ConfigTestSuite) TestFillDefault() { + host := "localhost" + port := 3306 + database := "forge" + username := "root" + password := "123123" + prefix := "goravel_" + singular := false + charset := "utf8mb4" + loc := "Local" + + tests := []struct { + name string + configs []contractsdatabase.Config + setup func() + expectConfigs []contractsdatabase.FullConfig + }{ + { + name: "success when configs is empty", + setup: func() {}, + configs: []contractsdatabase.Config{}, + }, + { + name: "success when configs have item but key is empty", + configs: []contractsdatabase.Config{{}}, + setup: func() { + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", s.connection)).Return(prefix).Once() + s.mockConfig.EXPECT().GetBool(fmt.Sprintf("database.connections.%s.singular", s.connection)).Return(singular).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("mysql").Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.host", s.connection)).Return(host).Once() + s.mockConfig.EXPECT().GetInt(fmt.Sprintf("database.connections.%s.port", s.connection)).Return(port).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.database", s.connection)).Return(database).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.username", s.connection)).Return(username).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.password", s.connection)).Return(password).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.charset", s.connection)).Return(charset).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.loc", s.connection)).Return(loc).Once() + }, + expectConfigs: []contractsdatabase.FullConfig{ + { + Connection: s.connection, + Driver: contractsdatabase.DriverMysql, + Prefix: prefix, + Singular: singular, + Charset: charset, + Loc: loc, + Config: contractsdatabase.Config{ + Host: host, + Port: port, + Database: database, + Username: username, + Password: password, + }, + }, + }, + }, + { + name: "success when configs have item", + configs: []contractsdatabase.Config{ + { + Host: host, + Port: port, + Database: database, + Username: username, + Password: password, + }, + }, + setup: func() { + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("mysql").Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", s.connection)).Return(prefix).Once() + s.mockConfig.EXPECT().GetBool(fmt.Sprintf("database.connections.%s.singular", s.connection)).Return(singular).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.charset", s.connection)).Return(charset).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.loc", s.connection)).Return(loc).Once() + }, + expectConfigs: []contractsdatabase.FullConfig{ + { + Connection: s.connection, + Driver: contractsdatabase.DriverMysql, + Prefix: prefix, + Singular: singular, + Charset: charset, + Loc: loc, + Config: contractsdatabase.Config{ + Database: database, + Host: host, + Port: port, + Username: username, + Password: password, + }, + }, + }, + }, + { + name: "success when sqlite", + configs: []contractsdatabase.Config{ + { + Database: database, + }, + }, + setup: func() { + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", s.connection)).Return(prefix).Once() + s.mockConfig.EXPECT().GetBool(fmt.Sprintf("database.connections.%s.singular", s.connection)).Return(singular).Once() + s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return("sqlite").Once() + }, + expectConfigs: []contractsdatabase.FullConfig{ + { + Connection: s.connection, + Driver: contractsdatabase.DriverSqlite, + Prefix: prefix, + Singular: singular, + Config: contractsdatabase.Config{ + Database: database, + }, + }, + }, + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + test.setup() + configs := s.configs.fillDefault(test.configs) + + s.Equal(test.expectConfigs, configs) + }) + } +} diff --git a/database/db/dsn.go b/database/db/dsn.go index 47c0e1764..492a64b45 100644 --- a/database/db/dsn.go +++ b/database/db/dsn.go @@ -3,67 +3,38 @@ package db import ( "fmt" - "github.com/goravel/framework/contracts/config" - databasecontract "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database" ) -type Dsn interface { - Mysql(config databasecontract.Config) string - Postgres(config databasecontract.Config) string - Sqlite(config databasecontract.Config) string - Sqlserver(config databasecontract.Config) string -} - -type DsnImpl struct { - config config.Config - connection string -} - -func NewDsnImpl(config config.Config, connection string) *DsnImpl { - return &DsnImpl{ - config: config, - connection: connection, - } -} - -func (d *DsnImpl) Mysql(config databasecontract.Config) string { - host := config.Host - if host == "" { +func Dsn(config database.FullConfig) string { + switch config.Driver { + case database.DriverMysql: + host := config.Host + if host == "" { + return "" + } + + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", + config.Username, config.Password, host, config.Port, config.Database, config.Charset, true, config.Loc) + case database.DriverPostgres: + host := config.Host + if host == "" { + return "" + } + + return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", + config.Username, config.Password, host, config.Port, config.Database, config.Sslmode, config.Timezone) + case database.DriverSqlite: + return fmt.Sprintf("%s?multi_stmts=true", config.Database) + case database.DriverSqlserver: + host := config.Host + if host == "" { + return "" + } + + return fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", + config.Username, config.Password, host, config.Port, config.Database, config.Charset) + default: return "" } - - charset := d.config.GetString("database.connections." + d.connection + ".charset") - loc := d.config.GetString("database.connections." + d.connection + ".loc") - - return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", - config.Username, config.Password, host, config.Port, config.Database, charset, true, loc) -} - -func (d *DsnImpl) Postgres(config databasecontract.Config) string { - host := config.Host - if host == "" { - return "" - } - - sslmode := d.config.GetString("database.connections." + d.connection + ".sslmode") - timezone := d.config.GetString("database.connections." + d.connection + ".timezone") - - return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", - config.Username, config.Password, host, config.Port, config.Database, sslmode, timezone) -} - -func (d *DsnImpl) Sqlite(config databasecontract.Config) string { - return fmt.Sprintf("%s?multi_stmts=true", config.Database) -} - -func (d *DsnImpl) Sqlserver(config databasecontract.Config) string { - host := config.Host - if host == "" { - return "" - } - - charset := d.config.GetString("database.connections." + d.connection + ".charset") - - return fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", - config.Username, config.Password, host, config.Port, config.Database, charset) } diff --git a/database/db/dsn_test.go b/database/db/dsn_test.go index 83dc5ddb5..7a5ca4245 100644 --- a/database/db/dsn_test.go +++ b/database/db/dsn_test.go @@ -4,11 +4,9 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/suite" + "github.com/stretchr/testify/assert" - databasecontract "github.com/goravel/framework/contracts/database" - "github.com/goravel/framework/contracts/database/orm" - configmock "github.com/goravel/framework/mocks/config" + "github.com/goravel/framework/contracts/database" ) const ( @@ -19,7 +17,7 @@ const ( testPassword = "123123" ) -var testConfig = databasecontract.Config{ +var testConfig = database.Config{ Host: testHost, Port: testPort, Database: testDatabase, @@ -27,54 +25,58 @@ var testConfig = databasecontract.Config{ Password: testPassword, } -type DsnTestSuite struct { - suite.Suite - mockConfig *configmock.Config -} - -func TestDsnTestSuite(t *testing.T) { - suite.Run(t, new(DsnTestSuite)) -} - -func (s *DsnTestSuite) SetupTest() { - s.mockConfig = &configmock.Config{} -} - -func (s *DsnTestSuite) TestMysql() { - connection := orm.DriverMysql.String() - dsn := NewDsnImpl(s.mockConfig, connection) - charset := "utf8mb4" - loc := "Local" - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.charset", connection)).Return(charset).Once() - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.loc", connection)).Return(loc).Once() - - s.Equal(fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", - testUsername, testPassword, testHost, testPort, testDatabase, charset, true, loc), dsn.Mysql(testConfig)) -} - -func (s *DsnTestSuite) TestPostgres() { - connection := orm.DriverPostgres.String() - dsn := NewDsnImpl(s.mockConfig, connection) - sslmode := "disable" - timezone := "UTC" - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.sslmode", connection)).Return(sslmode).Once() - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.timezone", connection)).Return(timezone).Once() - - s.Equal(fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", - testUsername, testPassword, testHost, testPort, testDatabase, sslmode, timezone), dsn.Postgres(testConfig)) -} - -func (s *DsnTestSuite) TestSqlite() { - dsn := NewDsnImpl(s.mockConfig, "") - s.Equal(fmt.Sprintf("%s?multi_stmts=true", testDatabase), dsn.Sqlite(testConfig)) -} - -func (s *DsnTestSuite) TestSqlserver() { - connection := orm.DriverSqlserver.String() - dsn := NewDsnImpl(s.mockConfig, connection) - charset := "utf8mb4" - s.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.charset", connection)).Return(charset).Once() +func TestDsn(t *testing.T) { + tests := []struct { + name string + config database.FullConfig + expectDsn string + }{ + { + name: "mysql", + config: database.FullConfig{ + Config: testConfig, + Driver: database.DriverMysql, + Charset: "utf8mb4", + Loc: "Local", + }, + expectDsn: fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", + testUsername, testPassword, testHost, testPort, testDatabase, "utf8mb4", true, "Local"), + }, + { + name: "postgres", + config: database.FullConfig{ + Config: testConfig, + Driver: database.DriverPostgres, + Sslmode: "disable", + Timezone: "UTC", + }, + expectDsn: fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", + testUsername, testPassword, testHost, testPort, testDatabase, "disable", "UTC"), + }, + { + name: "sqlite", + config: database.FullConfig{ + Config: testConfig, + Driver: database.DriverSqlite, + }, + expectDsn: fmt.Sprintf("%s?multi_stmts=true", testDatabase), + }, + { + name: "sqlserver", + config: database.FullConfig{ + Config: testConfig, + Driver: database.DriverSqlserver, + Charset: "utf8mb4", + }, + expectDsn: fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", + testUsername, testPassword, testHost, testPort, testDatabase, "utf8mb4"), + }, + } - s.Equal(fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", - testUsername, testPassword, testHost, testPort, testDatabase, charset), dsn.Sqlserver(testConfig)) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dsn := Dsn(test.config) + assert.Equal(t, test.expectDsn, dsn) + }) + } } diff --git a/database/gorm/cursor.go b/database/gorm/cursor.go index b03f14e34..b208397df 100644 --- a/database/gorm/cursor.go +++ b/database/gorm/cursor.go @@ -13,7 +13,7 @@ import ( ) type CursorImpl struct { - query *QueryImpl + query *Query row map[string]any } diff --git a/database/gorm/dialector.go b/database/gorm/dialector.go index d8d34f3e6..5d44f3ad9 100644 --- a/database/gorm/dialector.go +++ b/database/gorm/dialector.go @@ -4,61 +4,42 @@ import ( "fmt" "github.com/glebarez/sqlite" - "github.com/google/wire" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlserver" "gorm.io/gorm" - "github.com/goravel/framework/contracts/config" - databasecontract "github.com/goravel/framework/contracts/database" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/database/db" ) -var DialectorSet = wire.NewSet(NewDialectorImpl, wire.Bind(new(Dialector), new(*DialectorImpl))) -var _ Dialector = &DialectorImpl{} - -type Dialector interface { - Make(configs []databasecontract.Config) ([]gorm.Dialector, error) -} - -type DialectorImpl struct { - config config.Config - connection string - dsn db.Dsn -} - -func NewDialectorImpl(config config.Config, connection string) *DialectorImpl { - return &DialectorImpl{ - config: config, - connection: connection, - dsn: db.NewDsnImpl(config, connection), - } -} - -func (d *DialectorImpl) Make(configs []databasecontract.Config) ([]gorm.Dialector, error) { - driver := d.config.GetString(fmt.Sprintf("database.connections.%s.driver", d.connection)) - +func GetDialectors(configs []database.FullConfig) ([]gorm.Dialector, error) { var dialectors []gorm.Dialector - for _, item := range configs { + + for _, config := range configs { var dialector gorm.Dialector - var err error - switch orm.Driver(driver) { - case orm.DriverMysql: - dialector = d.mysql(item) - case orm.DriverPostgres: - dialector = d.postgres(item) - case orm.DriverSqlite: - dialector = d.sqlite(item) - case orm.DriverSqlserver: - dialector = d.sqlserver(item) - default: - err = fmt.Errorf("err database driver: %s, only support mysql, postgres, sqlite and sqlserver", driver) + dsn := db.Dsn(config) + if dsn == "" { + return nil, fmt.Errorf("failed to get dsn for %s", config.Connection) } - if err != nil { - return nil, err + switch config.Driver { + case database.DriverMysql: + dialector = mysql.New(mysql.Config{ + DSN: dsn, + }) + case database.DriverPostgres: + dialector = postgres.New(postgres.Config{ + DSN: dsn, + }) + case database.DriverSqlite: + dialector = sqlite.Open(dsn) + case database.DriverSqlserver: + dialector = sqlserver.New(sqlserver.Config{ + DSN: dsn, + }) + default: + return nil, fmt.Errorf("err database driver: %s, only support mysql, postgres, sqlite and sqlserver", config.Driver) } dialectors = append(dialectors, dialector) @@ -66,45 +47,3 @@ func (d *DialectorImpl) Make(configs []databasecontract.Config) ([]gorm.Dialecto return dialectors, nil } - -func (d *DialectorImpl) mysql(config databasecontract.Config) gorm.Dialector { - dsn := d.dsn.Mysql(config) - if dsn == "" { - return nil - } - - return mysql.New(mysql.Config{ - DSN: dsn, - }) -} - -func (d *DialectorImpl) postgres(config databasecontract.Config) gorm.Dialector { - dsn := d.dsn.Postgres(config) - if dsn == "" { - return nil - } - - return postgres.New(postgres.Config{ - DSN: dsn, - }) -} - -func (d *DialectorImpl) sqlite(config databasecontract.Config) gorm.Dialector { - dsn := d.dsn.Sqlite(config) - if dsn == "" { - return nil - } - - return sqlite.Open(dsn) -} - -func (d *DialectorImpl) sqlserver(config databasecontract.Config) gorm.Dialector { - dsn := d.dsn.Sqlserver(config) - if dsn == "" { - return nil - } - - return sqlserver.New(sqlserver.Config{ - DSN: dsn, - }) -} diff --git a/database/gorm/dialector_test.go b/database/gorm/dialector_test.go index f39ec21db..b27c5e9e4 100644 --- a/database/gorm/dialector_test.go +++ b/database/gorm/dialector_test.go @@ -1,97 +1,90 @@ package gorm import ( - "fmt" "testing" - - "github.com/glebarez/sqlite" - "github.com/stretchr/testify/suite" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlserver" - - databasecontract "github.com/goravel/framework/contracts/database" - "github.com/goravel/framework/contracts/database/orm" - configmock "github.com/goravel/framework/mocks/config" ) -type DialectorTestSuite struct { - suite.Suite - mockConfig *configmock.Config - config databasecontract.Config -} - -func TestDialectorTestSuite(t *testing.T) { - suite.Run(t, &DialectorTestSuite{ - config: databasecontract.Config{ - Host: "localhost", - Port: 3306, - Database: "forge", - Username: "root", - Password: "123123", - }, - }) -} - -func (s *DialectorTestSuite) SetupTest() { - s.mockConfig = &configmock.Config{} -} - -func (s *DialectorTestSuite) TestMysql() { - dialector := NewDialectorImpl(s.mockConfig, orm.DriverMysql.String()) - s.mockConfig.On("GetString", "database.connections.mysql.driver"). - Return(orm.DriverMysql.String()).Once() - s.mockConfig.On("GetString", "database.connections.mysql.charset"). - Return("utf8mb4").Once() - s.mockConfig.On("GetString", "database.connections.mysql.loc"). - Return("Local").Once() - dialectors, err := dialector.Make([]databasecontract.Config{s.config}) - s.Nil(err) - s.NotEmpty(dialectors) - s.Equal(mysql.New(mysql.Config{ - DSN: fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", - s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "utf8mb4", true, "Local"), - }), dialectors[0]) -} +//type DialectorTestSuite struct { +// suite.Suite +// mockConfig *configmock.Config +// config databasecontract.Config +//} +// +//func TestDialectorTestSuite(t *testing.T) { +// suite.Run(t, &DialectorTestSuite{ +// config: databasecontract.Config{ +// Host: "localhost", +// Port: 3306, +// Database: "forge", +// Username: "root", +// Password: "123123", +// }, +// }) +//} +// +//func (s *DialectorTestSuite) SetupTest() { +// s.mockConfig = &configmock.Config{} +//} +// +//func (s *DialectorTestSuite) TestMysql() { +// dialector := NewDialector(s.mockConfig, orm.DriverMysql.String()) +// s.mockConfig.On("GetString", "database.connections.mysql.driver"). +// Return(orm.DriverMysql.String()).Once() +// s.mockConfig.On("GetString", "database.connections.mysql.charset"). +// Return("utf8mb4").Once() +// s.mockConfig.On("GetString", "database.connections.mysql.loc"). +// Return("Local").Once() +// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) +// s.Nil(err) +// s.NotEmpty(dialectors) +// s.Equal(mysql.New(mysql.Config{ +// DSN: fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", +// s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "utf8mb4", true, "Local"), +// }), dialectors[0]) +//} +// +//func (s *DialectorTestSuite) TestPostgres() { +// dialector := NewDialector(s.mockConfig, orm.DriverPostgres.String()) +// s.mockConfig.On("GetString", "database.connections.postgres.driver"). +// Return(orm.DriverPostgres.String()).Once() +// s.mockConfig.On("GetString", "database.connections.postgres.sslmode"). +// Return("disable").Once() +// s.mockConfig.On("GetString", "database.connections.postgres.timezone"). +// Return("UTC").Once() +// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) +// s.Nil(err) +// s.NotEmpty(dialectors) +// s.Equal(postgres.New(postgres.Config{ +// DSN: fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", +// s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "disable", "UTC"), +// }), dialectors[0]) +//} +// +//func (s *DialectorTestSuite) TestSqlite() { +// dialector := NewDialector(s.mockConfig, orm.DriverSqlite.String()) +// s.mockConfig.On("GetString", "database.connections.sqlite.driver"). +// Return(orm.DriverSqlite.String()).Once() +// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) +// s.Nil(err) +// s.NotEmpty(dialectors) +// s.Equal(sqlite.Open(fmt.Sprintf("%s?multi_stmts=true", s.config.Database)), dialectors[0]) +//} +// +//func (s *DialectorTestSuite) TestSqlserver() { +// dialector := NewDialector(s.mockConfig, orm.DriverSqlserver.String()) +// s.mockConfig.On("GetString", "database.connections.sqlserver.driver"). +// Return(orm.DriverSqlserver.String()).Once() +// s.mockConfig.On("GetString", "database.connections.sqlserver.charset"). +// Return("utf8mb4").Once() +// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) +// s.Nil(err) +// s.NotEmpty(dialectors) +// s.Equal(sqlserver.New(sqlserver.Config{ +// DSN: fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", +// s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "utf8mb4"), +// }), dialectors[0]) +//} -func (s *DialectorTestSuite) TestPostgres() { - dialector := NewDialectorImpl(s.mockConfig, orm.DriverPostgres.String()) - s.mockConfig.On("GetString", "database.connections.postgres.driver"). - Return(orm.DriverPostgres.String()).Once() - s.mockConfig.On("GetString", "database.connections.postgres.sslmode"). - Return("disable").Once() - s.mockConfig.On("GetString", "database.connections.postgres.timezone"). - Return("UTC").Once() - dialectors, err := dialector.Make([]databasecontract.Config{s.config}) - s.Nil(err) - s.NotEmpty(dialectors) - s.Equal(postgres.New(postgres.Config{ - DSN: fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", - s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "disable", "UTC"), - }), dialectors[0]) -} - -func (s *DialectorTestSuite) TestSqlite() { - dialector := NewDialectorImpl(s.mockConfig, orm.DriverSqlite.String()) - s.mockConfig.On("GetString", "database.connections.sqlite.driver"). - Return(orm.DriverSqlite.String()).Once() - dialectors, err := dialector.Make([]databasecontract.Config{s.config}) - s.Nil(err) - s.NotEmpty(dialectors) - s.Equal(sqlite.Open(fmt.Sprintf("%s?multi_stmts=true", s.config.Database)), dialectors[0]) -} +func TestGetDialectors(t *testing.T) { -func (s *DialectorTestSuite) TestSqlserver() { - dialector := NewDialectorImpl(s.mockConfig, orm.DriverSqlserver.String()) - s.mockConfig.On("GetString", "database.connections.sqlserver.driver"). - Return(orm.DriverSqlserver.String()).Once() - s.mockConfig.On("GetString", "database.connections.sqlserver.charset"). - Return("utf8mb4").Once() - dialectors, err := dialector.Make([]databasecontract.Config{s.config}) - s.Nil(err) - s.NotEmpty(dialectors) - s.Equal(sqlserver.New(sqlserver.Config{ - DSN: fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", - s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "utf8mb4"), - }), dialectors[0]) } diff --git a/database/gorm/event.go b/database/gorm/event.go index b79a387ff..a2a20121f 100644 --- a/database/gorm/event.go +++ b/database/gorm/event.go @@ -17,10 +17,10 @@ type Event struct { destOfMap map[string]any model any modelOfMap map[string]any - query *QueryImpl + query *Query } -func NewEvent(query *QueryImpl, model, dest any) *Event { +func NewEvent(query *Query, model, dest any) *Event { return &Event{ dest: dest, model: model, @@ -92,7 +92,7 @@ func (e *Event) IsDirty(columns ...string) bool { } func (e *Event) Query() orm.Query { - return NewQueryImpl(e.query.ctx, e.query.config, e.query.connection, e.query.instance.Session(&gorm.Session{NewDB: true}), nil) + return NewQuery(e.query.ctx, e.query.config, e.query.connection, e.query.instance.Session(&gorm.Session{NewDB: true}), nil) } func (e *Event) SetAttribute(key string, value any) { diff --git a/database/gorm/event_test.go b/database/gorm/event_test.go index d541c9ae0..5ff0d6616 100644 --- a/database/gorm/event_test.go +++ b/database/gorm/event_test.go @@ -39,7 +39,7 @@ var testEventModel = TestEventModel{ ManageAt: testNow, high: 1, } -var testQuery = &QueryImpl{ +var testQuery = &Query{ instance: &gorm.DB{ Statement: &gorm.Statement{ Selects: []string{}, @@ -85,7 +85,7 @@ func (s *EventTestSuite) SetupTest() { func (s *EventTestSuite) TestSetAttribute() { // dest is map dest := map[string]any{"avatar": "avatar1"} - query := &QueryImpl{ + query := &Query{ instance: &gorm.DB{ Statement: &gorm.Statement{ Selects: []string{}, @@ -113,7 +113,7 @@ func (s *EventTestSuite) TestSetAttribute() { dest1 := &TestEventModel{ Avatar: "avatar1", } - query1 := &QueryImpl{ + query1 := &Query{ instance: &gorm.DB{ Statement: &gorm.Statement{ Selects: []string{}, @@ -242,7 +242,7 @@ func (s *EventTestSuite) TestValidColumn() { s.True(event.validColumn("manage")) s.False(event.validColumn("age")) - event.query = &QueryImpl{ + event.query = &Query{ instance: &gorm.DB{ Statement: &gorm.Statement{ Selects: []string{"name"}, @@ -255,7 +255,7 @@ func (s *EventTestSuite) TestValidColumn() { s.False(event.validColumn("avatar")) s.False(event.validColumn("Avatar")) - event.query = &QueryImpl{ + event.query = &Query{ instance: &gorm.DB{ Statement: &gorm.Statement{ Selects: []string{}, diff --git a/database/gorm/gorm.go b/database/gorm/gorm.go index d8a923851..8ba01fc9c 100644 --- a/database/gorm/gorm.go +++ b/database/gorm/gorm.go @@ -7,56 +7,39 @@ import ( "os" "time" - "github.com/google/wire" gormio "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/plugin/dbresolver" "github.com/goravel/framework/contracts/config" - databasecontract "github.com/goravel/framework/contracts/database" - "github.com/goravel/framework/contracts/database/gorm" - "github.com/goravel/framework/database/db" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/support/carbon" ) -var GormSet = wire.NewSet(NewGormImpl, wire.Bind(new(gorm.Gorm), new(*GormImpl))) -var _ gorm.Gorm = &GormImpl{} - -type GormImpl struct { - config config.Config - connection string - dbConfig db.Config - dialector Dialector - instance *gormio.DB +type Builder struct { + config config.Config + configs database.Configs + instance *gormio.DB } -func NewGormImpl(config config.Config, connection string, dbConfig db.Config, dialector Dialector) *GormImpl { - return &GormImpl{ - config: config, - connection: connection, - dbConfig: dbConfig, - dialector: dialector, +func NewGorm(config config.Config, configs database.Configs) (*gormio.DB, error) { + builder := &Builder{ + config: config, + configs: configs, } + + return builder.Build() } -func (r *GormImpl) Make() (*gormio.DB, error) { - readConfigs := r.dbConfig.Reads() - writeConfigs := r.dbConfig.Writes() +func (r *Builder) Build() (*gormio.DB, error) { + readConfigs := r.configs.Reads() + writeConfigs := r.configs.Writes() if len(writeConfigs) == 0 { return nil, errors.New("not found database configuration") } - writeDialectors, err := r.dialector.Make([]databasecontract.Config{writeConfigs[0]}) - if err != nil { - return nil, fmt.Errorf("init gorm dialector error: %v", err) - } - - if len(writeDialectors) == 0 { - return nil, errors.New("no write dialectors found") - } - - if err := r.init(writeDialectors[0]); err != nil { + if err := r.init(writeConfigs[0]); err != nil { return nil, err } @@ -71,7 +54,7 @@ func (r *GormImpl) Make() (*gormio.DB, error) { return r.instance, nil } -func (r *GormImpl) configurePool() error { +func (r *Builder) configurePool() error { sqlDB, err := r.instance.DB() if err != nil { return err @@ -85,17 +68,17 @@ func (r *GormImpl) configurePool() error { return nil } -func (r *GormImpl) configureReadWriteSeparate(readConfigs, writeConfigs []databasecontract.Config) error { +func (r *Builder) configureReadWriteSeparate(readConfigs, writeConfigs []database.FullConfig) error { if len(readConfigs) == 0 || len(writeConfigs) == 0 { return nil } - readDialectors, err := r.dialector.Make(readConfigs) + readDialectors, err := GetDialectors(readConfigs) if err != nil { return err } - writeDialectors, err := r.dialector.Make(writeConfigs) + writeDialectors, err := GetDialectors(writeConfigs) if err != nil { return err } @@ -108,7 +91,15 @@ func (r *GormImpl) configureReadWriteSeparate(readConfigs, writeConfigs []databa })) } -func (r *GormImpl) init(dialector gormio.Dialector) error { +func (r *Builder) init(fullConfig database.FullConfig) error { + dialectors, err := GetDialectors([]database.FullConfig{fullConfig}) + if err != nil { + return fmt.Errorf("init gorm dialector error: %v", err) + } + if len(dialectors) == 0 { + return errors.New("no dialectors found") + } + var logLevel gormlogger.LogLevel if r.config.GetBool("app.debug") { logLevel = gormlogger.Info @@ -122,7 +113,7 @@ func (r *GormImpl) init(dialector gormio.Dialector) error { IgnoreRecordNotFoundError: true, Colorful: true, }) - instance, err := gormio.Open(dialector, &gormio.Config{ + instance, err := gormio.Open(dialectors[0], &gormio.Config{ DisableForeignKeyConstraintWhenMigrating: true, SkipDefaultTransaction: true, Logger: logger.LogMode(logLevel), @@ -130,8 +121,8 @@ func (r *GormImpl) init(dialector gormio.Dialector) error { return carbon.Now().StdTime() }, NamingStrategy: schema.NamingStrategy{ - TablePrefix: r.config.GetString(fmt.Sprintf("database.connections.%s.prefix", r.connection)), - SingularTable: r.config.GetBool(fmt.Sprintf("database.connections.%s.singular", r.connection)), + TablePrefix: fullConfig.Prefix, + SingularTable: fullConfig.Singular, }, }) if err != nil { diff --git a/database/gorm/query.go b/database/gorm/query.go index 762b8c080..cd0e2c269 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -7,7 +7,6 @@ import ( "fmt" "reflect" - "github.com/google/wire" "github.com/spf13/cast" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -16,32 +15,29 @@ import ( "gorm.io/gorm/clause" "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/contracts/database/gorm" ormcontract "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/db" "github.com/goravel/framework/database/gorm/hints" "github.com/goravel/framework/database/orm" "github.com/goravel/framework/support/database" ) -var QuerySet = wire.NewSet(BuildQueryImpl, wire.Bind(new(ormcontract.Query), new(*QueryImpl))) -var _ ormcontract.Query = &QueryImpl{} - -type QueryImpl struct { +type Query struct { conditions Conditions config config.Config connection string ctx context.Context instance *gormio.DB - queries map[string]*QueryImpl + queries map[string]*Query } -func NewQueryImpl(ctx context.Context, config config.Config, connection string, db *gormio.DB, conditions *Conditions) *QueryImpl { - queryImpl := &QueryImpl{ +func NewQuery(ctx context.Context, config config.Config, connection string, db *gormio.DB, conditions *Conditions) *Query { + queryImpl := &Query{ config: config, connection: connection, ctx: ctx, instance: db, - queries: make(map[string]*QueryImpl), + queries: make(map[string]*Query), } if conditions != nil { @@ -51,25 +47,23 @@ func NewQueryImpl(ctx context.Context, config config.Config, connection string, return queryImpl } -func BuildQueryImpl(ctx context.Context, config config.Config, connection string, gorm gorm.Gorm) (*QueryImpl, error) { - db, err := gorm.Make() +func BuildQuery(ctx context.Context, config config.Config, connection string) (*Query, error) { + configs := db.NewConfigs(config, connection) + gorm, err := NewGorm(config, configs) if err != nil { return nil, err } - if ctx != nil { - db = db.WithContext(ctx) - } - return NewQueryImpl(ctx, config, connection, db, nil), nil + return NewQuery(ctx, config, connection, gorm, nil), nil } -func (r *QueryImpl) Association(association string) ormcontract.Association { +func (r *Query) Association(association string) ormcontract.Association { query := r.buildConditions() return query.instance.Association(association) } -func (r *QueryImpl) Begin() (ormcontract.Query, error) { +func (r *Query) Begin() (ormcontract.Query, error) { tx := r.instance.Begin() if tx.Error != nil { return nil, tx.Error @@ -78,17 +72,17 @@ func (r *QueryImpl) Begin() (ormcontract.Query, error) { return r.new(tx), nil } -func (r *QueryImpl) Commit() error { +func (r *Query) Commit() error { return r.instance.Commit().Error } -func (r *QueryImpl) Count(count *int64) error { +func (r *Query) Count(count *int64) error { query := r.buildConditions() return query.instance.Count(count).Error } -func (r *QueryImpl) Create(value any) error { +func (r *Query) Create(value any) error { query, err := r.refreshConnection(value) if err != nil { return err @@ -110,7 +104,7 @@ func (r *QueryImpl) Create(value any) error { return query.create(value) } -func (r *QueryImpl) Cursor() (chan ormcontract.Cursor, error) { +func (r *Query) Cursor() (chan ormcontract.Cursor, error) { with := r.conditions.with query := r.buildConditions() r.conditions.with = with @@ -138,7 +132,7 @@ func (r *QueryImpl) Cursor() (chan ormcontract.Cursor, error) { return cursorChan, err } -func (r *QueryImpl) Delete(dest ...any) (*ormcontract.Result, error) { +func (r *Query) Delete(dest ...any) (*ormcontract.Result, error) { var ( realDest any err error @@ -172,18 +166,18 @@ func (r *QueryImpl) Delete(dest ...any) (*ormcontract.Result, error) { }, nil } -func (r *QueryImpl) Distinct(args ...any) ormcontract.Query { +func (r *Query) Distinct(args ...any) ormcontract.Query { conditions := r.conditions conditions.distinct = append(conditions.distinct, args...) return r.setConditions(conditions) } -func (r *QueryImpl) Driver() ormcontract.Driver { +func (r *Query) Driver() ormcontract.Driver { return ormcontract.Driver(r.instance.Dialector.Name()) } -func (r *QueryImpl) Exec(sql string, values ...any) (*ormcontract.Result, error) { +func (r *Query) Exec(sql string, values ...any) (*ormcontract.Result, error) { query := r.buildConditions() result := query.instance.Exec(sql, values...) @@ -192,13 +186,13 @@ func (r *QueryImpl) Exec(sql string, values ...any) (*ormcontract.Result, error) }, result.Error } -func (r *QueryImpl) Exists(exists *bool) error { +func (r *Query) Exists(exists *bool) error { query := r.buildConditions() return query.instance.Select("1").Limit(1).Find(exists).Error } -func (r *QueryImpl) Find(dest any, conds ...any) error { +func (r *Query) Find(dest any, conds ...any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -216,7 +210,7 @@ func (r *QueryImpl) Find(dest any, conds ...any) error { return query.retrieved(dest) } -func (r *QueryImpl) FindOrFail(dest any, conds ...any) error { +func (r *Query) FindOrFail(dest any, conds ...any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -240,7 +234,7 @@ func (r *QueryImpl) FindOrFail(dest any, conds ...any) error { return query.retrieved(dest) } -func (r *QueryImpl) First(dest any) error { +func (r *Query) First(dest any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -260,7 +254,7 @@ func (r *QueryImpl) First(dest any) error { return query.retrieved(dest) } -func (r *QueryImpl) FirstOr(dest any, callback func() error) error { +func (r *Query) FirstOr(dest any, callback func() error) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -279,7 +273,7 @@ func (r *QueryImpl) FirstOr(dest any, callback func() error) error { return query.retrieved(dest) } -func (r *QueryImpl) FirstOrCreate(dest any, conds ...any) error { +func (r *Query) FirstOrCreate(dest any, conds ...any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -308,7 +302,7 @@ func (r *QueryImpl) FirstOrCreate(dest any, conds ...any) error { return query.Create(dest) } -func (r *QueryImpl) FirstOrFail(dest any) error { +func (r *Query) FirstOrFail(dest any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -327,7 +321,7 @@ func (r *QueryImpl) FirstOrFail(dest any) error { return query.retrieved(dest) } -func (r *QueryImpl) FirstOrNew(dest any, attributes any, values ...any) error { +func (r *Query) FirstOrNew(dest any, attributes any, values ...any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -352,7 +346,7 @@ func (r *QueryImpl) FirstOrNew(dest any, attributes any, values ...any) error { return nil } -func (r *QueryImpl) ForceDelete(dest ...any) (*ormcontract.Result, error) { +func (r *Query) ForceDelete(dest ...any) (*ormcontract.Result, error) { var ( realDest any err error @@ -388,18 +382,18 @@ func (r *QueryImpl) ForceDelete(dest ...any) (*ormcontract.Result, error) { }, res.Error } -func (r *QueryImpl) Get(dest any) error { +func (r *Query) Get(dest any) error { return r.Find(dest) } -func (r *QueryImpl) Group(name string) ormcontract.Query { +func (r *Query) Group(name string) ormcontract.Query { conditions := r.conditions conditions.group = name return r.setConditions(conditions) } -func (r *QueryImpl) Having(query any, args ...any) ormcontract.Query { +func (r *Query) Having(query any, args ...any) ormcontract.Query { conditions := r.conditions conditions.having = &Having{ query: query, @@ -409,11 +403,11 @@ func (r *QueryImpl) Having(query any, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) Instance() *gormio.DB { +func (r *Query) Instance() *gormio.DB { return r.instance } -func (r *QueryImpl) Join(query string, args ...any) ormcontract.Query { +func (r *Query) Join(query string, args ...any) ormcontract.Query { conditions := r.conditions conditions.join = append(conditions.join, Join{ query: query, @@ -423,14 +417,14 @@ func (r *QueryImpl) Join(query string, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) Limit(limit int) ormcontract.Query { +func (r *Query) Limit(limit int) ormcontract.Query { conditions := r.conditions conditions.limit = &limit return r.setConditions(conditions) } -func (r *QueryImpl) Load(model any, relation string, args ...any) error { +func (r *Query) Load(model any, relation string, args ...any) error { if relation == "" { return errors.New("relation cannot be empty") } @@ -461,7 +455,7 @@ func (r *QueryImpl) Load(model any, relation string, args ...any) error { return err } -func (r *QueryImpl) LoadMissing(model any, relation string, args ...any) error { +func (r *Query) LoadMissing(model any, relation string, args ...any) error { destType := reflect.TypeOf(model) if destType.Kind() != reflect.Pointer { return errors.New("model must be pointer") @@ -495,42 +489,42 @@ func (r *QueryImpl) LoadMissing(model any, relation string, args ...any) error { return r.Load(model, relation, args...) } -func (r *QueryImpl) LockForUpdate() ormcontract.Query { +func (r *Query) LockForUpdate() ormcontract.Query { conditions := r.conditions conditions.lockForUpdate = true return r.setConditions(conditions) } -func (r *QueryImpl) Model(value any) ormcontract.Query { +func (r *Query) Model(value any) ormcontract.Query { conditions := r.conditions conditions.model = value return r.setConditions(conditions) } -func (r *QueryImpl) Offset(offset int) ormcontract.Query { +func (r *Query) Offset(offset int) ormcontract.Query { conditions := r.conditions conditions.offset = &offset return r.setConditions(conditions) } -func (r *QueryImpl) Omit(columns ...string) ormcontract.Query { +func (r *Query) Omit(columns ...string) ormcontract.Query { conditions := r.conditions conditions.omit = columns return r.setConditions(conditions) } -func (r *QueryImpl) Order(value any) ormcontract.Query { +func (r *Query) Order(value any) ormcontract.Query { conditions := r.conditions conditions.order = append(r.conditions.order, value) return r.setConditions(conditions) } -func (r *QueryImpl) OrderBy(column string, direction ...string) ormcontract.Query { +func (r *Query) OrderBy(column string, direction ...string) ormcontract.Query { var orderDirection string if len(direction) > 0 { orderDirection = direction[0] @@ -540,11 +534,11 @@ func (r *QueryImpl) OrderBy(column string, direction ...string) ormcontract.Quer return r.Order(fmt.Sprintf("%s %s", column, orderDirection)) } -func (r *QueryImpl) OrderByDesc(column string) ormcontract.Query { +func (r *Query) OrderByDesc(column string) ormcontract.Query { return r.Order(fmt.Sprintf("%s DESC", column)) } -func (r *QueryImpl) InRandomOrder() ormcontract.Query { +func (r *Query) InRandomOrder() ormcontract.Query { order := "" switch r.Driver() { case ormcontract.DriverMysql: @@ -559,7 +553,7 @@ func (r *QueryImpl) InRandomOrder() ormcontract.Query { return r.Order(order) } -func (r *QueryImpl) OrWhere(query any, args ...any) ormcontract.Query { +func (r *Query) OrWhere(query any, args ...any) ormcontract.Query { conditions := r.conditions conditions.where = append(r.conditions.where, Where{ query: query, @@ -570,7 +564,7 @@ func (r *QueryImpl) OrWhere(query any, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) Paginate(page, limit int, dest any, total *int64) error { +func (r *Query) Paginate(page, limit int, dest any, total *int64) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -594,21 +588,21 @@ func (r *QueryImpl) Paginate(page, limit int, dest any, total *int64) error { return query.Offset(offset).Limit(limit).Find(dest) } -func (r *QueryImpl) Pluck(column string, dest any) error { +func (r *Query) Pluck(column string, dest any) error { query := r.buildConditions() return query.instance.Pluck(column, dest).Error } -func (r *QueryImpl) Raw(sql string, values ...any) ormcontract.Query { +func (r *Query) Raw(sql string, values ...any) ormcontract.Query { return r.new(r.instance.Raw(sql, values...)) } -func (r *QueryImpl) Rollback() error { +func (r *Query) Rollback() error { return r.instance.Rollback().Error } -func (r *QueryImpl) Save(value any) error { +func (r *Query) Save(value any) error { query, err := r.refreshConnection(value) if err != nil { return err @@ -666,11 +660,11 @@ func (r *QueryImpl) Save(value any) error { return nil } -func (r *QueryImpl) SaveQuietly(value any) error { +func (r *Query) SaveQuietly(value any) error { return r.WithoutEvents().Save(value) } -func (r *QueryImpl) Scan(dest any) error { +func (r *Query) Scan(dest any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -681,14 +675,14 @@ func (r *QueryImpl) Scan(dest any) error { return query.instance.Scan(dest).Error } -func (r *QueryImpl) Scopes(funcs ...func(ormcontract.Query) ormcontract.Query) ormcontract.Query { +func (r *Query) Scopes(funcs ...func(ormcontract.Query) ormcontract.Query) ormcontract.Query { conditions := r.conditions conditions.scopes = append(r.conditions.scopes, funcs...) return r.setConditions(conditions) } -func (r *QueryImpl) Select(query any, args ...any) ormcontract.Query { +func (r *Query) Select(query any, args ...any) ormcontract.Query { conditions := r.conditions conditions.selectColumns = &Select{ query: query, @@ -698,25 +692,25 @@ func (r *QueryImpl) Select(query any, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) SetContext(ctx context.Context) { +func (r *Query) SetContext(ctx context.Context) { r.ctx = ctx r.instance.Statement.Context = ctx } -func (r *QueryImpl) SharedLock() ormcontract.Query { +func (r *Query) SharedLock() ormcontract.Query { conditions := r.conditions conditions.sharedLock = true return r.setConditions(conditions) } -func (r *QueryImpl) Sum(column string, dest any) error { +func (r *Query) Sum(column string, dest any) error { query := r.buildConditions() return query.instance.Select("SUM(" + column + ")").Row().Scan(dest) } -func (r *QueryImpl) Table(name string, args ...any) ormcontract.Query { +func (r *Query) Table(name string, args ...any) ormcontract.Query { conditions := r.conditions conditions.table = &Table{ name: name, @@ -726,15 +720,15 @@ func (r *QueryImpl) Table(name string, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) ToSql() ormcontract.ToSql { +func (r *Query) ToSql() ormcontract.ToSql { return NewToSql(r.setConditions(r.conditions), false) } -func (r *QueryImpl) ToRawSql() ormcontract.ToSql { +func (r *Query) ToRawSql() ormcontract.ToSql { return NewToSql(r.setConditions(r.conditions), true) } -func (r *QueryImpl) Update(column any, value ...any) (*ormcontract.Result, error) { +func (r *Query) Update(column any, value ...any) (*ormcontract.Result, error) { query := r.buildConditions() if _, ok := column.(string); !ok && len(value) > 0 { @@ -778,7 +772,7 @@ func (r *QueryImpl) Update(column any, value ...any) (*ormcontract.Result, error return res, err } -func (r *QueryImpl) UpdateOrCreate(dest any, attributes any, values any) error { +func (r *Query) UpdateOrCreate(dest any, attributes any, values any) error { query, err := r.refreshConnection(dest) if err != nil { return err @@ -797,7 +791,7 @@ func (r *QueryImpl) UpdateOrCreate(dest any, attributes any, values any) error { return query.Create(dest) } -func (r *QueryImpl) Where(query any, args ...any) ormcontract.Query { +func (r *Query) Where(query any, args ...any) ormcontract.Query { conditions := r.conditions conditions.where = append(r.conditions.where, Where{ query: query, @@ -807,51 +801,51 @@ func (r *QueryImpl) Where(query any, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) WhereIn(column string, values []any) ormcontract.Query { +func (r *Query) WhereIn(column string, values []any) ormcontract.Query { return r.Where(fmt.Sprintf("%s IN ?", column), values) } -func (r *QueryImpl) OrWhereIn(column string, values []any) ormcontract.Query { +func (r *Query) OrWhereIn(column string, values []any) ormcontract.Query { return r.OrWhere(fmt.Sprintf("%s IN ?", column), values) } -func (r *QueryImpl) WhereNotIn(column string, values []any) ormcontract.Query { +func (r *Query) WhereNotIn(column string, values []any) ormcontract.Query { return r.Where(fmt.Sprintf("%s NOT IN ?", column), values) } -func (r *QueryImpl) OrWhereNotIn(column string, values []any) ormcontract.Query { +func (r *Query) OrWhereNotIn(column string, values []any) ormcontract.Query { return r.OrWhere(fmt.Sprintf("%s NOT IN ?", column), values) } -func (r *QueryImpl) WhereBetween(column string, x, y any) ormcontract.Query { +func (r *Query) WhereBetween(column string, x, y any) ormcontract.Query { return r.Where(fmt.Sprintf("%s BETWEEN %v AND %v", column, x, y)) } -func (r *QueryImpl) WhereNotBetween(column string, x, y any) ormcontract.Query { +func (r *Query) WhereNotBetween(column string, x, y any) ormcontract.Query { return r.Where(fmt.Sprintf("%s NOT BETWEEN %v AND %v", column, x, y)) } -func (r *QueryImpl) OrWhereBetween(column string, x, y any) ormcontract.Query { +func (r *Query) OrWhereBetween(column string, x, y any) ormcontract.Query { return r.OrWhere(fmt.Sprintf("%s BETWEEN %v AND %v", column, x, y)) } -func (r *QueryImpl) OrWhereNotBetween(column string, x, y any) ormcontract.Query { +func (r *Query) OrWhereNotBetween(column string, x, y any) ormcontract.Query { return r.OrWhere(fmt.Sprintf("%s NOT BETWEEN %v AND %v", column, x, y)) } -func (r *QueryImpl) OrWhereNull(column string) ormcontract.Query { +func (r *Query) OrWhereNull(column string) ormcontract.Query { return r.OrWhere(fmt.Sprintf("%s IS NULL", column)) } -func (r *QueryImpl) WhereNull(column string) ormcontract.Query { +func (r *Query) WhereNull(column string) ormcontract.Query { return r.Where(fmt.Sprintf("%s IS NULL", column)) } -func (r *QueryImpl) WhereNotNull(column string) ormcontract.Query { +func (r *Query) WhereNotNull(column string) ormcontract.Query { return r.Where(fmt.Sprintf("%s IS NOT NULL", column)) } -func (r *QueryImpl) With(query string, args ...any) ormcontract.Query { +func (r *Query) With(query string, args ...any) ormcontract.Query { conditions := r.conditions conditions.with = append(r.conditions.with, With{ query: query, @@ -861,21 +855,21 @@ func (r *QueryImpl) With(query string, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *QueryImpl) WithoutEvents() ormcontract.Query { +func (r *Query) WithoutEvents() ormcontract.Query { conditions := r.conditions conditions.withoutEvents = true return r.setConditions(conditions) } -func (r *QueryImpl) WithTrashed() ormcontract.Query { +func (r *Query) WithTrashed() ormcontract.Query { conditions := r.conditions conditions.withTrashed = true return r.setConditions(conditions) } -func (r *QueryImpl) buildConditions() *QueryImpl { +func (r *Query) buildConditions() *Query { query := r.buildModel() db := query.instance db = query.buildDistinct(db) @@ -898,7 +892,7 @@ func (r *QueryImpl) buildConditions() *QueryImpl { return query.new(db) } -func (r *QueryImpl) buildDistinct(db *gormio.DB) *gormio.DB { +func (r *Query) buildDistinct(db *gormio.DB) *gormio.DB { if len(r.conditions.distinct) == 0 { return db } @@ -909,7 +903,7 @@ func (r *QueryImpl) buildDistinct(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildGroup(db *gormio.DB) *gormio.DB { +func (r *Query) buildGroup(db *gormio.DB) *gormio.DB { if r.conditions.group == "" { return db } @@ -920,7 +914,7 @@ func (r *QueryImpl) buildGroup(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildHaving(db *gormio.DB) *gormio.DB { +func (r *Query) buildHaving(db *gormio.DB) *gormio.DB { if r.conditions.having == nil { return db } @@ -931,7 +925,7 @@ func (r *QueryImpl) buildHaving(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildJoin(db *gormio.DB) *gormio.DB { +func (r *Query) buildJoin(db *gormio.DB) *gormio.DB { if r.conditions.join == nil { return db } @@ -945,7 +939,7 @@ func (r *QueryImpl) buildJoin(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildLimit(db *gormio.DB) *gormio.DB { +func (r *Query) buildLimit(db *gormio.DB) *gormio.DB { if r.conditions.limit == nil { return db } @@ -956,7 +950,7 @@ func (r *QueryImpl) buildLimit(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildLockForUpdate(db *gormio.DB) *gormio.DB { +func (r *Query) buildLockForUpdate(db *gormio.DB) *gormio.DB { if !r.conditions.lockForUpdate { return db } @@ -977,7 +971,7 @@ func (r *QueryImpl) buildLockForUpdate(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildModel() *QueryImpl { +func (r *Query) buildModel() *Query { if r.conditions.model == nil { return r } @@ -990,7 +984,7 @@ func (r *QueryImpl) buildModel() *QueryImpl { return query.new(query.instance.Model(r.conditions.model)) } -func (r *QueryImpl) buildOffset(db *gormio.DB) *gormio.DB { +func (r *Query) buildOffset(db *gormio.DB) *gormio.DB { if r.conditions.offset == nil { return db } @@ -1001,7 +995,7 @@ func (r *QueryImpl) buildOffset(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildOmit(db *gormio.DB) *gormio.DB { +func (r *Query) buildOmit(db *gormio.DB) *gormio.DB { if len(r.conditions.omit) == 0 { return db } @@ -1012,7 +1006,7 @@ func (r *QueryImpl) buildOmit(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildOrder(db *gormio.DB) *gormio.DB { +func (r *Query) buildOrder(db *gormio.DB) *gormio.DB { if len(r.conditions.order) == 0 { return db } @@ -1026,7 +1020,7 @@ func (r *QueryImpl) buildOrder(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildSelectColumns(db *gormio.DB) *gormio.DB { +func (r *Query) buildSelectColumns(db *gormio.DB) *gormio.DB { if r.conditions.selectColumns == nil { return db } @@ -1037,7 +1031,7 @@ func (r *QueryImpl) buildSelectColumns(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildScopes(db *gormio.DB) *gormio.DB { +func (r *Query) buildScopes(db *gormio.DB) *gormio.DB { if len(r.conditions.scopes) == 0 { return db } @@ -1048,7 +1042,7 @@ func (r *QueryImpl) buildScopes(db *gormio.DB) *gormio.DB { gormFuncs = append(gormFuncs, func(tx *gormio.DB) *gormio.DB { queryImpl := r.new(tx) query := currentScope(queryImpl) - queryImpl = query.(*QueryImpl) + queryImpl = query.(*Query) queryImpl = queryImpl.buildConditions() return queryImpl.instance @@ -1061,7 +1055,7 @@ func (r *QueryImpl) buildScopes(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildSharedLock(db *gormio.DB) *gormio.DB { +func (r *Query) buildSharedLock(db *gormio.DB) *gormio.DB { if !r.conditions.sharedLock { return db } @@ -1082,7 +1076,7 @@ func (r *QueryImpl) buildSharedLock(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildTable(db *gormio.DB) *gormio.DB { +func (r *Query) buildTable(db *gormio.DB) *gormio.DB { if r.conditions.table == nil { return db } @@ -1093,7 +1087,7 @@ func (r *QueryImpl) buildTable(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildWhere(db *gormio.DB) *gormio.DB { +func (r *Query) buildWhere(db *gormio.DB) *gormio.DB { if len(r.conditions.where) == 0 { return db } @@ -1111,7 +1105,7 @@ func (r *QueryImpl) buildWhere(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildWith(db *gormio.DB) *gormio.DB { +func (r *Query) buildWith(db *gormio.DB) *gormio.DB { if len(r.conditions.with) == 0 { return db } @@ -1122,9 +1116,9 @@ func (r *QueryImpl) buildWith(db *gormio.DB) *gormio.DB { if arg, ok := item.args[0].(func(ormcontract.Query) ormcontract.Query); ok { newArgs := []any{ func(tx *gormio.DB) *gormio.DB { - queryImpl := NewQueryImpl(r.ctx, r.config, r.connection, tx, nil) + queryImpl := NewQuery(r.ctx, r.config, r.connection, tx, nil) query := arg(queryImpl) - queryImpl = query.(*QueryImpl) + queryImpl = query.(*Query) queryImpl = queryImpl.buildConditions() return queryImpl.instance @@ -1146,7 +1140,7 @@ func (r *QueryImpl) buildWith(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) buildWithTrashed(db *gormio.DB) *gormio.DB { +func (r *Query) buildWithTrashed(db *gormio.DB) *gormio.DB { if !r.conditions.withTrashed { return db } @@ -1157,11 +1151,11 @@ func (r *QueryImpl) buildWithTrashed(db *gormio.DB) *gormio.DB { return db } -func (r *QueryImpl) clearConditions() { +func (r *Query) clearConditions() { r.conditions = Conditions{} } -func (r *QueryImpl) create(dest any) error { +func (r *Query) create(dest any) error { if err := r.saving(dest); err != nil { return err } @@ -1183,31 +1177,31 @@ func (r *QueryImpl) create(dest any) error { return nil } -func (r *QueryImpl) created(dest any) error { +func (r *Query) created(dest any) error { return r.event(ormcontract.EventCreated, r.instance.Statement.Model, dest) } -func (r *QueryImpl) creating(dest any) error { +func (r *Query) creating(dest any) error { return r.event(ormcontract.EventCreating, r.instance.Statement.Model, dest) } -func (r *QueryImpl) deleting(dest any) error { +func (r *Query) deleting(dest any) error { return r.event(ormcontract.EventDeleting, r.instance.Statement.Model, dest) } -func (r *QueryImpl) deleted(dest any) error { +func (r *Query) deleted(dest any) error { return r.event(ormcontract.EventDeleted, r.instance.Statement.Model, dest) } -func (r *QueryImpl) forceDeleting(dest any) error { +func (r *Query) forceDeleting(dest any) error { return r.event(ormcontract.EventForceDeleting, r.instance.Statement.Model, dest) } -func (r *QueryImpl) forceDeleted(dest any) error { +func (r *Query) forceDeleted(dest any) error { return r.event(ormcontract.EventForceDeleted, r.instance.Statement.Model, dest) } -func (r *QueryImpl) event(event ormcontract.EventType, model, dest any) error { +func (r *Query) event(event ormcontract.EventType, model, dest any) error { if r.conditions.withoutEvents { return nil } @@ -1255,13 +1249,13 @@ func (r *QueryImpl) event(event ormcontract.EventType, model, dest any) error { return nil } -func (r *QueryImpl) new(db *gormio.DB) *QueryImpl { - query := NewQueryImpl(r.ctx, r.config, r.connection, db, &r.conditions) +func (r *Query) new(db *gormio.DB) *Query { + query := NewQuery(r.ctx, r.config, r.connection, db, &r.conditions) return query } -func (r *QueryImpl) omitCreate(value any) error { +func (r *Query) omitCreate(value any) error { if len(r.instance.Statement.Omits) > 1 { for _, val := range r.instance.Statement.Omits { if val == orm.Associations { @@ -1301,7 +1295,7 @@ func (r *QueryImpl) omitCreate(value any) error { return nil } -func (r *QueryImpl) omitSave(value any) error { +func (r *Query) omitSave(value any) error { for _, val := range r.instance.Statement.Omits { if val == orm.Associations { return r.instance.Omit(orm.Associations).Save(value).Error @@ -1311,7 +1305,7 @@ func (r *QueryImpl) omitSave(value any) error { return r.instance.Save(value).Error } -func (r *QueryImpl) refreshConnection(model any) (*QueryImpl, error) { +func (r *Query) refreshConnection(model any) (*Query, error) { connection, err := getModelConnection(model) if err != nil { return nil, err @@ -1323,13 +1317,13 @@ func (r *QueryImpl) refreshConnection(model any) (*QueryImpl, error) { query, ok := r.queries[connection] if !ok { var err error - query, err = InitializeQuery(r.ctx, r.config, connection) + query, err = BuildQuery(r.ctx, r.config, connection) if err != nil { return nil, err } if r.queries == nil { - r.queries = make(map[string]*QueryImpl) + r.queries = make(map[string]*Query) } r.queries[connection] = query } @@ -1339,23 +1333,23 @@ func (r *QueryImpl) refreshConnection(model any) (*QueryImpl, error) { return query, nil } -func (r *QueryImpl) retrieved(dest any) error { +func (r *Query) retrieved(dest any) error { return r.event(ormcontract.EventRetrieved, nil, dest) } -func (r *QueryImpl) save(value any) error { +func (r *Query) save(value any) error { return r.instance.Omit(orm.Associations).Save(value).Error } -func (r *QueryImpl) saved(dest any) error { +func (r *Query) saved(dest any) error { return r.event(ormcontract.EventSaved, r.instance.Statement.Model, dest) } -func (r *QueryImpl) saving(dest any) error { +func (r *Query) saving(dest any) error { return r.event(ormcontract.EventSaving, r.instance.Statement.Model, dest) } -func (r *QueryImpl) selectCreate(value any) error { +func (r *Query) selectCreate(value any) error { if len(r.instance.Statement.Selects) > 1 { for _, val := range r.instance.Statement.Selects { if val == orm.Associations { @@ -1389,7 +1383,7 @@ func (r *QueryImpl) selectCreate(value any) error { return nil } -func (r *QueryImpl) selectSave(value any) error { +func (r *Query) selectSave(value any) error { for _, val := range r.instance.Statement.Selects { if val == orm.Associations { return r.instance.Session(&gormio.Session{FullSaveAssociations: true}).Save(value).Error @@ -1403,22 +1397,22 @@ func (r *QueryImpl) selectSave(value any) error { return nil } -func (r *QueryImpl) setConditions(conditions Conditions) *QueryImpl { +func (r *Query) setConditions(conditions Conditions) *Query { query := r.new(r.instance) query.conditions = conditions return query } -func (r *QueryImpl) updating(dest any) error { +func (r *Query) updating(dest any) error { return r.event(ormcontract.EventUpdating, r.instance.Statement.Model, dest) } -func (r *QueryImpl) updated(dest any) error { +func (r *Query) updated(dest any) error { return r.event(ormcontract.EventUpdated, r.instance.Statement.Model, dest) } -func (r *QueryImpl) updates(values any) (*ormcontract.Result, error) { +func (r *Query) updates(values any) (*ormcontract.Result, error) { if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { return nil, errors.New("cannot set Select and Omits at the same time") } diff --git a/database/gorm/query_test.go b/database/gorm/query_test.go index 108691128..7fb5b9e85 100644 --- a/database/gorm/query_test.go +++ b/database/gorm/query_test.go @@ -2646,7 +2646,7 @@ func (s *QueryTestSuite) TestRefreshConnection() { s.Run(test.name, func() { test.setup() testQuery := s.queries[contractsorm.DriverPostgres] - query, err := testQuery.Query().(*QueryImpl).refreshConnection(test.model) + query, err := testQuery.Query().(*Query).refreshConnection(test.model) if test.expectErr != "" { s.EqualError(err, test.expectErr) } else { diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index 2685f8a5b..7eb816908 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -20,7 +20,7 @@ const ( TestModelNormal // Switch this value to control the test model. - TestModel = TestModelNormal + TestModel = TestModelMinimum ) type TestTable int @@ -229,19 +229,19 @@ func NewTestQuery(docker testing.DatabaseDriver, withPrefixAndSingular ...bool) } var ( - query *QueryImpl + query *Query err error ) if len(withPrefixAndSingular) > 0 && withPrefixAndSingular[0] { mockDriver.WithPrefixAndSingular() - query, err = InitializeQuery(testContext, mockConfig, docker.Driver().String()) + query, err = BuildQuery(testContext, mockConfig, docker.Driver().String()) } else { mockDriver.Common() - query, err = InitializeQuery(testContext, mockConfig, docker.Driver().String()) + query, err = BuildQuery(testContext, mockConfig, docker.Driver().String()) } if err != nil { - panic(fmt.Sprintf("connect to %s failed", docker.Driver().String())) + panic(fmt.Sprintf("connect to %s failed: %v", docker.Driver().String(), err)) } testQuery.query = query @@ -276,7 +276,7 @@ func (r *TestQuery) QueryOfReadWrite(config TestReadWriteConfig) (orm.Query, err mockDriver := GetMockDriver(r.Docker(), mockConfig, r.Docker().Driver().String()) mockDriver.ReadWrite(config) - return InitializeQuery(testContext, mockConfig, r.docker.Driver().String()) + return BuildQuery(testContext, mockConfig, r.docker.Driver().String()) } func GetMockDriver(docker testing.DatabaseDriver, mockConfig *mocksconfig.Config, connection string) testMockDriver { diff --git a/database/gorm/to_sql.go b/database/gorm/to_sql.go index 85a5b30dc..5591daaec 100644 --- a/database/gorm/to_sql.go +++ b/database/gorm/to_sql.go @@ -5,11 +5,11 @@ import ( ) type ToSql struct { - query *QueryImpl + query *Query raw bool } -func NewToSql(query *QueryImpl, raw bool) *ToSql { +func NewToSql(query *Query, raw bool) *ToSql { return &ToSql{ query: query, raw: raw, diff --git a/database/gorm/to_sql_test.go b/database/gorm/to_sql_test.go index 258ea7e69..23fe76abb 100644 --- a/database/gorm/to_sql_test.go +++ b/database/gorm/to_sql_test.go @@ -33,19 +33,19 @@ func (s *ToSqlTestSuite) SetupSuite() { func (s *ToSqlTestSuite) SetupTest() {} func (s *ToSqlTestSuite) TestCount() { - toSql := NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), false) s.Equal("SELECT count(*) FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), true) s.Equal("SELECT count(*) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) } func (s *ToSqlTestSuite) TestCreate() { user := User{Name: "to_sql_create"} - toSql := NewToSql(s.query.(*QueryImpl), false) + toSql := NewToSql(s.query.(*Query), false) s.Equal("INSERT INTO \"users\" (\"created_at\",\"updated_at\",\"deleted_at\",\"name\",\"bio\",\"avatar\") VALUES ($1,$2,$3,$4,$5,$6) RETURNING \"id\"", toSql.Create(&user)) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), true) s.Contains(toSql.Create(&user), "INSERT INTO \"users\" (\"created_at\",\"updated_at\",\"deleted_at\",\"name\",\"bio\",\"avatar\") VALUES (") s.Contains(toSql.Create(&user), ",NULL,'to_sql_create',NULL,'')") @@ -55,124 +55,124 @@ func (s *ToSqlTestSuite) TestCreate() { } func (s *ToSqlTestSuite) TestDelete() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("UPDATE \"users\" SET \"deleted_at\"=$1 WHERE \"id\" = $2 AND \"users\".\"deleted_at\" IS NULL", toSql.Delete(&User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql = NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("DELETE FROM \"roles\" WHERE \"id\" = $1", toSql.Delete(&Role{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) sql := toSql.Delete(&User{}) s.Contains(sql, "UPDATE \"users\" SET \"deleted_at\"=") s.Contains(sql, "WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL") - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("DELETE FROM \"roles\" WHERE \"id\" = 1", toSql.Delete(&Role{})) } func (s *ToSqlTestSuite) TestFind() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("SELECT * FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", toSql.Find(&User{})) - toSql = NewToSql(s.query.(*QueryImpl), false) + toSql = NewToSql(s.query.(*Query), false) s.Equal("SELECT * FROM \"users\" WHERE \"users\".\"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", toSql.Find(&User{}, 1)) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("SELECT * FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Find(&User{})) - toSql = NewToSql(s.query.(*QueryImpl), true) + toSql = NewToSql(s.query.(*Query), true) s.Equal("SELECT * FROM \"users\" WHERE \"users\".\"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Find(&User{}, 1)) } func (s *ToSqlTestSuite) TestFirst() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("SELECT * FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL ORDER BY \"users\".\"id\" LIMIT $2", toSql.First(&User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("SELECT * FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL ORDER BY \"users\".\"id\" LIMIT 1", toSql.First(&User{})) } func (s *ToSqlTestSuite) TestForceDelete() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("DELETE FROM \"users\" WHERE \"id\" = $1", toSql.ForceDelete(&User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql = NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("DELETE FROM \"roles\" WHERE \"id\" = $1", toSql.ForceDelete(&Role{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("DELETE FROM \"users\" WHERE \"id\" = 1", toSql.ForceDelete(&User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("DELETE FROM \"roles\" WHERE \"id\" = 1", toSql.ForceDelete(&Role{})) } func (s *ToSqlTestSuite) TestGet() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("SELECT * FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", toSql.Get([]User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("SELECT * FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Get([]User{})) } func (s *ToSqlTestSuite) TestPluck() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("SELECT \"id\" FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", toSql.Pluck("id", User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("SELECT \"id\" FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Pluck("id", User{})) } func (s *ToSqlTestSuite) TestSave() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("INSERT INTO \"users\" (\"created_at\",\"updated_at\",\"deleted_at\",\"name\",\"bio\",\"avatar\") VALUES ($1,$2,$3,$4,$5,$6) RETURNING \"id\"", toSql.Save(&User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) sql := toSql.Save(&User{}) s.Contains(sql, "INSERT INTO \"users\" (\"created_at\",\"updated_at\",\"deleted_at\",\"name\",\"bio\",\"avatar\") VALUES (") s.Contains(sql, ",NULL,'',NULL,'')") } func (s *ToSqlTestSuite) TestSum() { - toSql := NewToSql(s.query.Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Where("id", 1).(*Query), false) s.Equal("SELECT SUM(id) FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", toSql.Sum("id", User{})) - toSql = NewToSql(s.query.Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Where("id", 1).(*Query), true) s.Equal("SELECT SUM(id) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Sum("id", User{})) } func (s *ToSqlTestSuite) TestUpdate() { - toSql := NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), false) + toSql := NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), false) s.Equal("UPDATE \"users\" SET \"name\"=$1,\"updated_at\"=$2 WHERE \"id\" = $3 AND \"users\".\"deleted_at\" IS NULL", toSql.Update("name", "goravel")) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), true) sql := toSql.Update("name", "goravel") s.Contains(sql, "UPDATE \"users\" SET \"name\"='goravel',\"updated_at\"=") s.Contains(sql, "WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL") - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), false) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), false) s.Empty(toSql.Update(0, "goravel")) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), true) s.Empty(toSql.Update(0, "goravel")) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), false) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), false) s.Equal("UPDATE \"users\" SET \"name\"=$1,\"updated_at\"=$2 WHERE \"id\" = $3 AND \"users\".\"deleted_at\" IS NULL", toSql.Update(map[string]any{ "name": "goravel", })) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), true) sql = toSql.Update(map[string]any{ "name": "goravel", }) s.Contains(sql, "UPDATE \"users\" SET \"name\"='goravel',\"updated_at\"=") s.Contains(sql, "WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL") - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), false) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), false) s.Equal("UPDATE \"users\" SET \"updated_at\"=$1,\"name\"=$2 WHERE \"id\" = $3 AND \"users\".\"deleted_at\" IS NULL", toSql.Update(User{ Name: "goravel", })) - toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*QueryImpl), true) + toSql = NewToSql(s.query.Model(&User{}).Where("id", 1).(*Query), true) sql = toSql.Update(User{ Name: "goravel", }) diff --git a/database/gorm/wire.go b/database/gorm/wire.go deleted file mode 100644 index c43157861..000000000 --- a/database/gorm/wire.go +++ /dev/null @@ -1,29 +0,0 @@ -//go:build wireinject -// +build wireinject - -// The build tag makes sure the stub is not built in the final build. - -package gorm - -import ( - "context" - - "github.com/google/wire" - - "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/database/db" -) - -//go:generate wire -func InitializeGorm(config config.Config, connection string) *GormImpl { - wire.Build(NewGormImpl, db.ConfigSet, DialectorSet) - - return nil -} - -//go:generate wire -func InitializeQuery(ctx context.Context, config config.Config, connection string) (*QueryImpl, error) { - wire.Build(BuildQueryImpl, GormSet, db.ConfigSet, DialectorSet) - - return nil, nil -} diff --git a/database/gorm/wire_gen.go b/database/gorm/wire_gen.go deleted file mode 100644 index 7147c077a..000000000 --- a/database/gorm/wire_gen.go +++ /dev/null @@ -1,35 +0,0 @@ -// Code generated by Wire. DO NOT EDIT. - -//go:generate go run github.com/google/wire/cmd/wire -//go:build !wireinject -// +build !wireinject - -package gorm - -import ( - "context" - "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/database/db" -) - -// Injectors from wire.go: - -//go:generate wire -func InitializeGorm(config2 config.Config, connection string) *GormImpl { - configImpl := db.NewConfigImpl(config2, connection) - dialectorImpl := NewDialectorImpl(config2, connection) - gormImpl := NewGormImpl(config2, connection, configImpl, dialectorImpl) - return gormImpl -} - -//go:generate wire -func InitializeQuery(ctx context.Context, config2 config.Config, connection string) (*QueryImpl, error) { - configImpl := db.NewConfigImpl(config2, connection) - dialectorImpl := NewDialectorImpl(config2, connection) - gormImpl := NewGormImpl(config2, connection, configImpl, dialectorImpl) - queryImpl, err := BuildQueryImpl(ctx, config2, connection, gormImpl) - if err != nil { - return nil, err - } - return queryImpl, nil -} diff --git a/database/orm.go b/database/orm.go index f5ba723fa..02675015f 100644 --- a/database/orm.go +++ b/database/orm.go @@ -5,43 +5,48 @@ import ( "database/sql" "fmt" - "github.com/pkg/errors" - "github.com/goravel/framework/contracts/config" - ormcontract "github.com/goravel/framework/contracts/database/orm" - databasegorm "github.com/goravel/framework/database/gorm" + contractsorm "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/gorm" "github.com/goravel/framework/database/orm" "github.com/goravel/framework/support/color" ) -var _ ormcontract.Orm = (*OrmImpl)(nil) - -type OrmImpl struct { +type Orm struct { ctx context.Context config config.Config connection string - query ormcontract.Query - queries map[string]ormcontract.Query + query contractsorm.Query + queries map[string]contractsorm.Query } -func NewOrmImpl(ctx context.Context, config config.Config, connection string, query ormcontract.Query) (*OrmImpl, error) { - return &OrmImpl{ +func NewOrm(ctx context.Context, config config.Config, connection string, query contractsorm.Query) (*Orm, error) { + return &Orm{ ctx: ctx, config: config, connection: connection, query: query, - queries: map[string]ormcontract.Query{ + queries: map[string]contractsorm.Query{ connection: query, }, }, nil } -func (r *OrmImpl) Connection(name string) ormcontract.Orm { +func BuildOrm(ctx context.Context, config config.Config, connection string) (*Orm, error) { + query, err := gorm.BuildQuery(ctx, config, connection) + if err != nil { + return nil, fmt.Errorf("[Orm] Build query for %s connection error: %v", connection, err) + } + + return NewOrm(ctx, config, connection, query) +} + +func (r *Orm) Connection(name string) contractsorm.Orm { if name == "" { name = r.config.GetString("database.default") } if instance, exist := r.queries[name]; exist { - return &OrmImpl{ + return &Orm{ ctx: r.ctx, config: r.config, connection: name, @@ -50,7 +55,7 @@ func (r *OrmImpl) Connection(name string) ormcontract.Orm { } } - queue, err := databasegorm.InitializeQuery(r.ctx, r.config, name) + queue, err := gorm.BuildQuery(r.ctx, r.config, name) if err != nil || queue == nil { color.Red().Println(fmt.Sprintf("[Orm] Init %s connection error: %v", name, err)) @@ -59,7 +64,7 @@ func (r *OrmImpl) Connection(name string) ormcontract.Orm { r.queries[name] = queue - return &OrmImpl{ + return &Orm{ ctx: r.ctx, config: r.config, connection: name, @@ -68,28 +73,32 @@ func (r *OrmImpl) Connection(name string) ormcontract.Orm { } } -func (r *OrmImpl) DB() (*sql.DB, error) { - query := r.Query().(*databasegorm.QueryImpl) +func (r *Orm) DB() (*sql.DB, error) { + query := r.Query().(*gorm.Query) return query.Instance().DB() } -func (r *OrmImpl) Query() ormcontract.Query { +func (r *Orm) Query() contractsorm.Query { return r.query } -func (r *OrmImpl) Factory() ormcontract.Factory { +func (r *Orm) Factory() contractsorm.Factory { return NewFactoryImpl(r.Query()) } -func (r *OrmImpl) Observe(model any, observer ormcontract.Observer) { +func (r *Orm) Observe(model any, observer contractsorm.Observer) { orm.Observers = append(orm.Observers, orm.Observer{ Model: model, Observer: observer, }) } -func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Query) error) error { +func (r *Orm) Refresh() { + appFacade.Refresh(BindingOrm) +} + +func (r *Orm) Transaction(txFunc func(tx contractsorm.Query) error) error { tx, err := r.Query().Begin() if err != nil { return err @@ -97,7 +106,7 @@ func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Query) error) error { if err := txFunc(tx); err != nil { if err := tx.Rollback(); err != nil { - return errors.Wrapf(err, "rollback error: %v", err) + return fmt.Errorf("rollback error: %v", err) } return err @@ -106,16 +115,16 @@ func (r *OrmImpl) Transaction(txFunc func(tx ormcontract.Query) error) error { } } -func (r *OrmImpl) WithContext(ctx context.Context) ormcontract.Orm { +func (r *Orm) WithContext(ctx context.Context) contractsorm.Orm { for _, query := range r.queries { - query := query.(*databasegorm.QueryImpl) + query := query.(*gorm.Query) query.SetContext(ctx) } - query := r.query.(*databasegorm.QueryImpl) + query := r.query.(*gorm.Query) query.SetContext(ctx) - return &OrmImpl{ + return &Orm{ ctx: ctx, config: r.config, connection: r.connection, diff --git a/database/orm_test.go b/database/orm_test.go index 604508caf..ea178cef3 100644 --- a/database/orm_test.go +++ b/database/orm_test.go @@ -27,7 +27,7 @@ type User struct { type OrmSuite struct { suite.Suite - orm *OrmImpl + orm *Orm testQueries map[contractsorm.Driver]*gorm.TestQuery } @@ -50,7 +50,7 @@ func (s *OrmSuite) SetupTest() { queries[key.String()] = query.Query() } - s.orm = &OrmImpl{ + s.orm = &Orm{ connection: contractsorm.DriverPostgres.String(), ctx: context.Background(), query: queries[contractsorm.DriverPostgres.String()], diff --git a/database/service_provider.go b/database/service_provider.go index 5d7a361a0..06110b797 100644 --- a/database/service_provider.go +++ b/database/service_provider.go @@ -14,17 +14,19 @@ const BindingOrm = "goravel.orm" const BindingSchema = "goravel.schema" const BindingSeeder = "goravel.seeder" +var appFacade foundation.Application + type ServiceProvider struct { } -func (database *ServiceProvider) Register(app foundation.Application) { +func (r *ServiceProvider) Register(app foundation.Application) { app.Singleton(BindingOrm, func(app foundation.Application) (any, error) { + ctx := context.Background() config := app.MakeConfig() - defaultConnection := config.GetString("database.default") - - orm, err := InitializeOrm(context.Background(), config, defaultConnection) + connection := config.GetString("database.default") + orm, err := BuildOrm(ctx, config, connection) if err != nil { - return nil, fmt.Errorf("[Orm] Init %s connection error: %v", defaultConnection, err) + return nil, fmt.Errorf("[Orm] Init %s connection error: %v", connection, err) } return orm, nil @@ -46,11 +48,12 @@ func (database *ServiceProvider) Register(app foundation.Application) { }) } -func (database *ServiceProvider) Boot(app foundation.Application) { - database.registerCommands(app) +func (r *ServiceProvider) Boot(app foundation.Application) { + appFacade = app + r.registerCommands(app) } -func (database *ServiceProvider) registerCommands(app foundation.Application) { +func (r *ServiceProvider) registerCommands(app foundation.Application) { if artisanFacade := app.MakeArtisan(); artisanFacade != nil { config := app.MakeConfig() seeder := app.MakeSeeder() diff --git a/database/wire.go b/database/wire.go deleted file mode 100644 index e7fdb6260..000000000 --- a/database/wire.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build wireinject -// +build wireinject - -// The build tag makes sure the stub is not built in the final build. - -package database - -import ( - "context" - - "github.com/google/wire" - - "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/database/db" - "github.com/goravel/framework/database/gorm" -) - -//go:generate wire -func InitializeOrm(ctx context.Context, config config.Config, connection string) (*OrmImpl, error) { - wire.Build(NewOrmImpl, gorm.QuerySet, gorm.GormSet, db.ConfigSet, gorm.DialectorSet) - - return nil, nil -} diff --git a/database/wire_gen.go b/database/wire_gen.go deleted file mode 100644 index f78d32167..000000000 --- a/database/wire_gen.go +++ /dev/null @@ -1,32 +0,0 @@ -// Code generated by Wire. DO NOT EDIT. - -//go:generate go run github.com/google/wire/cmd/wire -//go:build !wireinject -// +build !wireinject - -package database - -import ( - "context" - "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/database/db" - "github.com/goravel/framework/database/gorm" -) - -// Injectors from wire.go: - -//go:generate wire -func InitializeOrm(ctx context.Context, config2 config.Config, connection string) (*OrmImpl, error) { - configImpl := db.NewConfigImpl(config2, connection) - dialectorImpl := gorm.NewDialectorImpl(config2, connection) - gormImpl := gorm.NewGormImpl(config2, connection, configImpl, dialectorImpl) - queryImpl, err := gorm.BuildQueryImpl(ctx, config2, connection, gormImpl) - if err != nil { - return nil, err - } - ormImpl, err := NewOrmImpl(ctx, config2, connection, queryImpl) - if err != nil { - return nil, err - } - return ormImpl, nil -} diff --git a/database/wire_interface.go b/database/wire_interface.go deleted file mode 100644 index a837d7212..000000000 --- a/database/wire_interface.go +++ /dev/null @@ -1,24 +0,0 @@ -package database - -import ( - "context" - - "github.com/goravel/framework/contracts/config" - contractsgorm "github.com/goravel/framework/contracts/database/gorm" - "github.com/goravel/framework/contracts/database/orm" - "github.com/goravel/framework/database/gorm" -) - -type InitializeImpl struct{} - -func NewInitializeImpl() *InitializeImpl { - return &InitializeImpl{} -} - -func (receive *InitializeImpl) InitializeGorm(config config.Config, connection string) contractsgorm.Gorm { - return gorm.InitializeGorm(config, connection) -} - -func (receive *InitializeImpl) InitializeQuery(ctx context.Context, config config.Config, connection string) (orm.Query, error) { - return gorm.InitializeQuery(ctx, config, connection) -} diff --git a/foundation/container.go b/foundation/container.go index 8f589314e..1a2c46306 100644 --- a/foundation/container.go +++ b/foundation/container.go @@ -331,6 +331,10 @@ func (c *Container) MakeWith(key any, parameters map[string]any) (any, error) { return c.make(key, parameters) } +func (c *Container) Refresh(key any) { + c.instances.Delete(key) +} + func (c *Container) Singleton(key any, callback func(app contractsfoundation.Application) (any, error)) { c.bindings.Store(key, instance{concrete: callback, shared: true}) } diff --git a/go.mod b/go.mod index 863378dc9..3a78e4779 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/golang-module/carbon/v2 v2.3.12 github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 - github.com/google/wire v0.6.0 github.com/gookit/validate v1.5.2 github.com/goravel/file-rotatelogs/v2 v2.4.2 github.com/hashicorp/go-multierror v1.1.1 diff --git a/go.sum b/go.sum index 0b3edb001..ce957626e 100644 --- a/go.sum +++ b/go.sum @@ -341,14 +341,11 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S3 github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI= -github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -659,8 +656,6 @@ golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58 golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -700,8 +695,6 @@ golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -747,8 +740,6 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -775,8 +766,6 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -837,8 +826,6 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -849,8 +836,6 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= -golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -866,7 +851,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= @@ -930,8 +914,6 @@ golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/mocks/database/Configs.go b/mocks/database/Configs.go new file mode 100644 index 000000000..aad275f98 --- /dev/null +++ b/mocks/database/Configs.go @@ -0,0 +1,129 @@ +// Code generated by mockery. DO NOT EDIT. + +package database + +import ( + database "github.com/goravel/framework/contracts/database" + mock "github.com/stretchr/testify/mock" +) + +// Configs is an autogenerated mock type for the Configs type +type Configs struct { + mock.Mock +} + +type Configs_Expecter struct { + mock *mock.Mock +} + +func (_m *Configs) EXPECT() *Configs_Expecter { + return &Configs_Expecter{mock: &_m.Mock} +} + +// Reads provides a mock function with given fields: +func (_m *Configs) Reads() []database.FullConfig { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Reads") + } + + var r0 []database.FullConfig + if rf, ok := ret.Get(0).(func() []database.FullConfig); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.FullConfig) + } + } + + return r0 +} + +// Configs_Reads_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Reads' +type Configs_Reads_Call struct { + *mock.Call +} + +// Reads is a helper method to define mock.On call +func (_e *Configs_Expecter) Reads() *Configs_Reads_Call { + return &Configs_Reads_Call{Call: _e.mock.On("Reads")} +} + +func (_c *Configs_Reads_Call) Run(run func()) *Configs_Reads_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Configs_Reads_Call) Return(_a0 []database.FullConfig) *Configs_Reads_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Configs_Reads_Call) RunAndReturn(run func() []database.FullConfig) *Configs_Reads_Call { + _c.Call.Return(run) + return _c +} + +// Writes provides a mock function with given fields: +func (_m *Configs) Writes() []database.FullConfig { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Writes") + } + + var r0 []database.FullConfig + if rf, ok := ret.Get(0).(func() []database.FullConfig); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.FullConfig) + } + } + + return r0 +} + +// Configs_Writes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Writes' +type Configs_Writes_Call struct { + *mock.Call +} + +// Writes is a helper method to define mock.On call +func (_e *Configs_Expecter) Writes() *Configs_Writes_Call { + return &Configs_Writes_Call{Call: _e.mock.On("Writes")} +} + +func (_c *Configs_Writes_Call) Run(run func()) *Configs_Writes_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Configs_Writes_Call) Return(_a0 []database.FullConfig) *Configs_Writes_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Configs_Writes_Call) RunAndReturn(run func() []database.FullConfig) *Configs_Writes_Call { + _c.Call.Return(run) + return _c +} + +// NewConfigs creates a new instance of Configs. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewConfigs(t interface { + mock.TestingT + Cleanup(func()) +}) *Configs { + mock := &Configs{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/database/gorm/Gorm.go b/mocks/database/gorm/Gorm.go deleted file mode 100644 index b9be532b4..000000000 --- a/mocks/database/gorm/Gorm.go +++ /dev/null @@ -1,92 +0,0 @@ -// Code generated by mockery. DO NOT EDIT. - -package gorm - -import ( - mock "github.com/stretchr/testify/mock" - gorm "gorm.io/gorm" -) - -// Gorm is an autogenerated mock type for the Gorm type -type Gorm struct { - mock.Mock -} - -type Gorm_Expecter struct { - mock *mock.Mock -} - -func (_m *Gorm) EXPECT() *Gorm_Expecter { - return &Gorm_Expecter{mock: &_m.Mock} -} - -// Make provides a mock function with given fields: -func (_m *Gorm) Make() (*gorm.DB, error) { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Make") - } - - var r0 *gorm.DB - var r1 error - if rf, ok := ret.Get(0).(func() (*gorm.DB, error)); ok { - return rf() - } - if rf, ok := ret.Get(0).(func() *gorm.DB); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*gorm.DB) - } - } - - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Gorm_Make_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Make' -type Gorm_Make_Call struct { - *mock.Call -} - -// Make is a helper method to define mock.On call -func (_e *Gorm_Expecter) Make() *Gorm_Make_Call { - return &Gorm_Make_Call{Call: _e.mock.On("Make")} -} - -func (_c *Gorm_Make_Call) Run(run func()) *Gorm_Make_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *Gorm_Make_Call) Return(_a0 *gorm.DB, _a1 error) *Gorm_Make_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Gorm_Make_Call) RunAndReturn(run func() (*gorm.DB, error)) *Gorm_Make_Call { - _c.Call.Return(run) - return _c -} - -// NewGorm creates a new instance of Gorm. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewGorm(t interface { - mock.TestingT - Cleanup(func()) -}) *Gorm { - mock := &Gorm{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/mocks/database/gorm/Initialize.go b/mocks/database/gorm/Initialize.go deleted file mode 100644 index 018d812ca..000000000 --- a/mocks/database/gorm/Initialize.go +++ /dev/null @@ -1,151 +0,0 @@ -// Code generated by mockery. DO NOT EDIT. - -package gorm - -import ( - context "context" - - config "github.com/goravel/framework/contracts/config" - - gorm "github.com/goravel/framework/contracts/database/gorm" - - mock "github.com/stretchr/testify/mock" - - orm "github.com/goravel/framework/contracts/database/orm" -) - -// Initialize is an autogenerated mock type for the Initialize type -type Initialize struct { - mock.Mock -} - -type Initialize_Expecter struct { - mock *mock.Mock -} - -func (_m *Initialize) EXPECT() *Initialize_Expecter { - return &Initialize_Expecter{mock: &_m.Mock} -} - -// InitializeGorm provides a mock function with given fields: _a0, connection -func (_m *Initialize) InitializeGorm(_a0 config.Config, connection string) gorm.Gorm { - ret := _m.Called(_a0, connection) - - if len(ret) == 0 { - panic("no return value specified for InitializeGorm") - } - - var r0 gorm.Gorm - if rf, ok := ret.Get(0).(func(config.Config, string) gorm.Gorm); ok { - r0 = rf(_a0, connection) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(gorm.Gorm) - } - } - - return r0 -} - -// Initialize_InitializeGorm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitializeGorm' -type Initialize_InitializeGorm_Call struct { - *mock.Call -} - -// InitializeGorm is a helper method to define mock.On call -// - _a0 config.Config -// - connection string -func (_e *Initialize_Expecter) InitializeGorm(_a0 interface{}, connection interface{}) *Initialize_InitializeGorm_Call { - return &Initialize_InitializeGorm_Call{Call: _e.mock.On("InitializeGorm", _a0, connection)} -} - -func (_c *Initialize_InitializeGorm_Call) Run(run func(_a0 config.Config, connection string)) *Initialize_InitializeGorm_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(config.Config), args[1].(string)) - }) - return _c -} - -func (_c *Initialize_InitializeGorm_Call) Return(_a0 gorm.Gorm) *Initialize_InitializeGorm_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *Initialize_InitializeGorm_Call) RunAndReturn(run func(config.Config, string) gorm.Gorm) *Initialize_InitializeGorm_Call { - _c.Call.Return(run) - return _c -} - -// InitializeQuery provides a mock function with given fields: ctx, _a1, connection -func (_m *Initialize) InitializeQuery(ctx context.Context, _a1 config.Config, connection string) (orm.Query, error) { - ret := _m.Called(ctx, _a1, connection) - - if len(ret) == 0 { - panic("no return value specified for InitializeQuery") - } - - var r0 orm.Query - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, config.Config, string) (orm.Query, error)); ok { - return rf(ctx, _a1, connection) - } - if rf, ok := ret.Get(0).(func(context.Context, config.Config, string) orm.Query); ok { - r0 = rf(ctx, _a1, connection) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(orm.Query) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, config.Config, string) error); ok { - r1 = rf(ctx, _a1, connection) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Initialize_InitializeQuery_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitializeQuery' -type Initialize_InitializeQuery_Call struct { - *mock.Call -} - -// InitializeQuery is a helper method to define mock.On call -// - ctx context.Context -// - _a1 config.Config -// - connection string -func (_e *Initialize_Expecter) InitializeQuery(ctx interface{}, _a1 interface{}, connection interface{}) *Initialize_InitializeQuery_Call { - return &Initialize_InitializeQuery_Call{Call: _e.mock.On("InitializeQuery", ctx, _a1, connection)} -} - -func (_c *Initialize_InitializeQuery_Call) Run(run func(ctx context.Context, _a1 config.Config, connection string)) *Initialize_InitializeQuery_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(config.Config), args[2].(string)) - }) - return _c -} - -func (_c *Initialize_InitializeQuery_Call) Return(_a0 orm.Query, _a1 error) *Initialize_InitializeQuery_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Initialize_InitializeQuery_Call) RunAndReturn(run func(context.Context, config.Config, string) (orm.Query, error)) *Initialize_InitializeQuery_Call { - _c.Call.Return(run) - return _c -} - -// NewInitialize creates a new instance of Initialize. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewInitialize(t interface { - mock.TestingT - Cleanup(func()) -}) *Initialize { - mock := &Initialize{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/mocks/database/orm/Orm.go b/mocks/database/orm/Orm.go index 932c6d837..0bfbe5d6e 100644 --- a/mocks/database/orm/Orm.go +++ b/mocks/database/orm/Orm.go @@ -257,6 +257,38 @@ func (_c *Orm_Query_Call) RunAndReturn(run func() orm.Query) *Orm_Query_Call { return _c } +// Refresh provides a mock function with given fields: +func (_m *Orm) Refresh() { + _m.Called() +} + +// Orm_Refresh_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Refresh' +type Orm_Refresh_Call struct { + *mock.Call +} + +// Refresh is a helper method to define mock.On call +func (_e *Orm_Expecter) Refresh() *Orm_Refresh_Call { + return &Orm_Refresh_Call{Call: _e.mock.On("Refresh")} +} + +func (_c *Orm_Refresh_Call) Run(run func()) *Orm_Refresh_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Orm_Refresh_Call) Return() *Orm_Refresh_Call { + _c.Call.Return() + return _c +} + +func (_c *Orm_Refresh_Call) RunAndReturn(run func()) *Orm_Refresh_Call { + _c.Call.Return(run) + return _c +} + // Transaction provides a mock function with given fields: txFunc func (_m *Orm) Transaction(txFunc func(orm.Query) error) error { ret := _m.Called(txFunc) diff --git a/mocks/foundation/Application.go b/mocks/foundation/Application.go index 9ffdcc2c6..64328820a 100644 --- a/mocks/foundation/Application.go +++ b/mocks/foundation/Application.go @@ -2084,6 +2084,39 @@ func (_c *Application_Publishes_Call) RunAndReturn(run func(string, map[string]s return _c } +// Refresh provides a mock function with given fields: key +func (_m *Application) Refresh(key interface{}) { + _m.Called(key) +} + +// Application_Refresh_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Refresh' +type Application_Refresh_Call struct { + *mock.Call +} + +// Refresh is a helper method to define mock.On call +// - key interface{} +func (_e *Application_Expecter) Refresh(key interface{}) *Application_Refresh_Call { + return &Application_Refresh_Call{Call: _e.mock.On("Refresh", key)} +} + +func (_c *Application_Refresh_Call) Run(run func(key interface{})) *Application_Refresh_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *Application_Refresh_Call) Return() *Application_Refresh_Call { + _c.Call.Return() + return _c +} + +func (_c *Application_Refresh_Call) RunAndReturn(run func(interface{})) *Application_Refresh_Call { + _c.Call.Return(run) + return _c +} + // SetJson provides a mock function with given fields: json func (_m *Application) SetJson(json foundation.Json) { _m.Called(json) diff --git a/testing/docker/database.go b/testing/docker/database.go index 67205011a..d731e1ee1 100644 --- a/testing/docker/database.go +++ b/testing/docker/database.go @@ -1,24 +1,23 @@ package docker import ( - "context" "fmt" contractsconfig "github.com/goravel/framework/contracts/config" contractsconsole "github.com/goravel/framework/contracts/console" + contractsorm "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/contracts/database/seeder" "github.com/goravel/framework/contracts/foundation" "github.com/goravel/framework/contracts/testing" - frameworkdatabase "github.com/goravel/framework/database" supportdocker "github.com/goravel/framework/support/docker" ) type Database struct { testing.DatabaseDriver - app foundation.Application artisan contractsconsole.Artisan config contractsconfig.Config connection string + orm contractsorm.Orm } func NewDatabase(app foundation.Application, connection string) (*Database, error) { @@ -43,39 +42,27 @@ func NewDatabase(app foundation.Application, connection string) (*Database, erro databaseDriver := supportdocker.DatabaseDriver(supportdocker.ContainerType(driver), database, username, password) return &Database{ - app: app, + DatabaseDriver: databaseDriver, artisan: artisanFacade, config: config, connection: connection, - DatabaseDriver: databaseDriver, + orm: app.MakeOrm(), }, nil } -func (receiver *Database) Build() error { - if err := receiver.DatabaseDriver.Build(); err != nil { +func (r *Database) Build() error { + if err := r.DatabaseDriver.Build(); err != nil { return err } - receiver.config.Add(fmt.Sprintf("database.connections.%s.port", receiver.connection), receiver.DatabaseDriver.Config().Port) - receiver.artisan.Call("migrate") - - // TODO Find a better way to refresh the database connection - receiver.app.Singleton(frameworkdatabase.BindingOrm, func(app foundation.Application) (any, error) { - config := app.MakeConfig() - defaultConnection := config.GetString("database.default") - - orm, err := frameworkdatabase.InitializeOrm(context.Background(), config, defaultConnection) - if err != nil { - return nil, fmt.Errorf("[Orm] Init %s connection error: %v", defaultConnection, err) - } - - return orm, nil - }) + r.config.Add(fmt.Sprintf("database.connections.%s.port", r.connection), r.DatabaseDriver.Config().Port) + r.artisan.Call("migrate") + r.orm.Refresh() return nil } -func (receiver *Database) Seed(seeds ...seeder.Seeder) { +func (r *Database) Seed(seeds ...seeder.Seeder) { command := "db:seed" if len(seeds) > 0 { command += " --seeder" @@ -84,5 +71,5 @@ func (receiver *Database) Seed(seeds ...seeder.Seeder) { } } - receiver.artisan.Call(command) + r.artisan.Call(command) } From 2197e0dfc264b00ff46c24d69c24555ca3da5a5d Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 11:58:30 +0800 Subject: [PATCH 02/11] remove driver --- contracts/database/orm/constants.go | 15 -- contracts/database/orm/orm.go | 4 +- contracts/testing/testing.go | 4 +- database/console/migrate.go | 12 +- database/console/migrate_creator.go | 10 +- database/console/test_utils.go | 12 +- database/gorm/query.go | 179 +++++++-------- database/gorm/query_test.go | 25 ++- database/gorm/test_utils.go | 322 +++++++++++++-------------- database/migration/blueprint_test.go | 10 +- database/migration/schema.go | 9 +- database/migration/schema_test.go | 13 +- database/migration/sql_driver.go | 8 +- database/orm_test.go | 7 +- foundation/application_test.go | 78 +++---- mocks/database/orm/Query.go | 16 +- mocks/testing/Database.go | 17 +- mocks/testing/DatabaseDriver.go | 17 +- support/docker/mysql.go | 6 +- support/docker/mysql_test.go | 4 +- support/docker/postgres.go | 6 +- support/docker/postgres_test.go | 4 +- support/docker/sqlite.go | 6 +- support/docker/sqlite_test.go | 4 +- support/docker/sqlserver.go | 6 +- support/docker/sqlserver_test.go | 2 +- testing/docker/database_test.go | 18 +- 27 files changed, 402 insertions(+), 412 deletions(-) delete mode 100644 contracts/database/orm/constants.go diff --git a/contracts/database/orm/constants.go b/contracts/database/orm/constants.go deleted file mode 100644 index 1217b0c32..000000000 --- a/contracts/database/orm/constants.go +++ /dev/null @@ -1,15 +0,0 @@ -package orm - -// DEPRECATED Move to contracts/database/config.go -const ( - DriverMysql Driver = "mysql" - DriverPostgres Driver = "postgres" - DriverSqlite Driver = "sqlite" - DriverSqlserver Driver = "sqlserver" -) - -type Driver string - -func (d Driver) String() string { - return string(d) -} diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index e15846ba7..29b53da8e 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -3,6 +3,8 @@ package orm import ( "context" "database/sql" + + "github.com/goravel/framework/contracts/database" ) type Orm interface { @@ -42,7 +44,7 @@ type Query interface { // Distinct specifies distinct fields to query. Distinct(args ...any) Query // Driver gets the driver for the query. - Driver() Driver + Driver() database.Driver // Exec executes raw sql Exec(sql string, values ...any) (*Result, error) // Exists returns true if matching records exist; otherwise, it returns false. diff --git a/contracts/testing/testing.go b/contracts/testing/testing.go index 859eae46e..8ba8ed8fd 100644 --- a/contracts/testing/testing.go +++ b/contracts/testing/testing.go @@ -1,7 +1,7 @@ package testing import ( - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/seeder" ) @@ -31,7 +31,7 @@ type DatabaseDriver interface { // Image gets the database image. Image(image Image) // Driver gets the database driver name. - Driver() orm.Driver + Driver() database.Driver // Stop the database. Stop() error } diff --git a/database/console/migrate.go b/database/console/migrate.go index 2b449e396..ea952557e 100644 --- a/database/console/migrate.go +++ b/database/console/migrate.go @@ -11,7 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4/database/sqlserver" "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/database/console/driver" databasedb "github.com/goravel/framework/database/db" "github.com/goravel/framework/support" @@ -31,8 +31,8 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { return nil, errors.New("not found database configuration") } - switch orm.Driver(driver) { - case orm.DriverMysql: + switch database.Driver(driver) { + case database.DriverMysql: mysqlDsn := databasedb.Dsn(writeConfigs[0]) if mysqlDsn == "" { return nil, nil @@ -51,7 +51,7 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { } return migrate.NewWithDatabaseInstance(dir, "mysql", instance) - case orm.DriverPostgres: + case database.DriverPostgres: postgresDsn := databasedb.Dsn(writeConfigs[0]) if postgresDsn == "" { return nil, nil @@ -70,7 +70,7 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { } return migrate.NewWithDatabaseInstance(dir, "postgres", instance) - case orm.DriverSqlite: + case database.DriverSqlite: sqliteDsn := databasedb.Dsn(writeConfigs[0]) if sqliteDsn == "" { return nil, nil @@ -89,7 +89,7 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { } return migrate.NewWithDatabaseInstance(dir, "sqlite3", instance) - case orm.DriverSqlserver: + case database.DriverSqlserver: sqlserverDsn := databasedb.Dsn(writeConfigs[0]) if sqlserverDsn == "" { return nil, nil diff --git a/database/console/migrate_creator.go b/database/console/migrate_creator.go index f50673e5d..de0ec6380 100644 --- a/database/console/migrate_creator.go +++ b/database/console/migrate_creator.go @@ -6,7 +6,7 @@ import ( "strings" "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/database/migration" "github.com/goravel/framework/support/carbon" "github.com/goravel/framework/support/file" @@ -49,20 +49,20 @@ func (receiver *MigrateCreator) getStub(table string, create bool) (string, stri } driver := receiver.config.GetString("database.connections." + receiver.config.GetString("database.default") + ".driver") - switch orm.Driver(driver) { - case orm.DriverPostgres: + switch database.Driver(driver) { + case database.DriverPostgres: if create { return migration.PostgresStubs{}.CreateUp(), migration.PostgresStubs{}.CreateDown() } return migration.PostgresStubs{}.UpdateUp(), migration.PostgresStubs{}.UpdateDown() - case orm.DriverSqlite: + case database.DriverSqlite: if create { return migration.SqliteStubs{}.CreateUp(), migration.SqliteStubs{}.CreateDown() } return migration.SqliteStubs{}.UpdateUp(), migration.SqliteStubs{}.UpdateDown() - case orm.DriverSqlserver: + case database.DriverSqlserver: if create { return migration.SqlserverStubs{}.CreateUp(), migration.SqlserverStubs{}.CreateDown() } diff --git a/database/console/test_utils.go b/database/console/test_utils.go index 2d8658783..f07b8c07c 100644 --- a/database/console/test_utils.go +++ b/database/console/test_utils.go @@ -1,19 +1,19 @@ package console import ( - contractsorm "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/support/file" ) -func createMigrations(driver contractsorm.Driver) { +func createMigrations(driver database.Driver) { switch driver { - case contractsorm.DriverPostgres: + case database.DriverPostgres: createPostgresMigrations() - case contractsorm.DriverMysql: + case database.DriverMysql: createMysqlMigrations() - case contractsorm.DriverSqlserver: + case database.DriverSqlserver: createSqlserverMigrations() - case contractsorm.DriverSqlite: + case database.DriverSqlite: createSqliteMigrations() } } diff --git a/database/gorm/query.go b/database/gorm/query.go index cd0e2c269..04eaeaff8 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -15,7 +15,8 @@ import ( "gorm.io/gorm/clause" "github.com/goravel/framework/contracts/config" - ormcontract "github.com/goravel/framework/contracts/database/orm" + contractsdatabase "github.com/goravel/framework/contracts/database" + contractsorm "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/db" "github.com/goravel/framework/database/gorm/hints" "github.com/goravel/framework/database/orm" @@ -57,13 +58,13 @@ func BuildQuery(ctx context.Context, config config.Config, connection string) (* return NewQuery(ctx, config, connection, gorm, nil), nil } -func (r *Query) Association(association string) ormcontract.Association { +func (r *Query) Association(association string) contractsorm.Association { query := r.buildConditions() return query.instance.Association(association) } -func (r *Query) Begin() (ormcontract.Query, error) { +func (r *Query) Begin() (contractsorm.Query, error) { tx := r.instance.Begin() if tx.Error != nil { return nil, tx.Error @@ -104,13 +105,13 @@ func (r *Query) Create(value any) error { return query.create(value) } -func (r *Query) Cursor() (chan ormcontract.Cursor, error) { +func (r *Query) Cursor() (chan contractsorm.Cursor, error) { with := r.conditions.with query := r.buildConditions() r.conditions.with = with var err error - cursorChan := make(chan ormcontract.Cursor) + cursorChan := make(chan contractsorm.Cursor) go func() { var rows *sql.Rows rows, err = query.instance.Rows() @@ -132,7 +133,7 @@ func (r *Query) Cursor() (chan ormcontract.Cursor, error) { return cursorChan, err } -func (r *Query) Delete(dest ...any) (*ormcontract.Result, error) { +func (r *Query) Delete(dest ...any) (*contractsorm.Result, error) { var ( realDest any err error @@ -161,27 +162,27 @@ func (r *Query) Delete(dest ...any) (*ormcontract.Result, error) { return nil, err } - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: res.RowsAffected, }, nil } -func (r *Query) Distinct(args ...any) ormcontract.Query { +func (r *Query) Distinct(args ...any) contractsorm.Query { conditions := r.conditions conditions.distinct = append(conditions.distinct, args...) return r.setConditions(conditions) } -func (r *Query) Driver() ormcontract.Driver { - return ormcontract.Driver(r.instance.Dialector.Name()) +func (r *Query) Driver() contractsdatabase.Driver { + return contractsdatabase.Driver(r.instance.Dialector.Name()) } -func (r *Query) Exec(sql string, values ...any) (*ormcontract.Result, error) { +func (r *Query) Exec(sql string, values ...any) (*contractsorm.Result, error) { query := r.buildConditions() result := query.instance.Exec(sql, values...) - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: result.RowsAffected, }, result.Error } @@ -346,7 +347,7 @@ func (r *Query) FirstOrNew(dest any, attributes any, values ...any) error { return nil } -func (r *Query) ForceDelete(dest ...any) (*ormcontract.Result, error) { +func (r *Query) ForceDelete(dest ...any) (*contractsorm.Result, error) { var ( realDest any err error @@ -377,7 +378,7 @@ func (r *Query) ForceDelete(dest ...any) (*ormcontract.Result, error) { } } - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: res.RowsAffected, }, res.Error } @@ -386,14 +387,14 @@ func (r *Query) Get(dest any) error { return r.Find(dest) } -func (r *Query) Group(name string) ormcontract.Query { +func (r *Query) Group(name string) contractsorm.Query { conditions := r.conditions conditions.group = name return r.setConditions(conditions) } -func (r *Query) Having(query any, args ...any) ormcontract.Query { +func (r *Query) Having(query any, args ...any) contractsorm.Query { conditions := r.conditions conditions.having = &Having{ query: query, @@ -407,7 +408,7 @@ func (r *Query) Instance() *gormio.DB { return r.instance } -func (r *Query) Join(query string, args ...any) ormcontract.Query { +func (r *Query) Join(query string, args ...any) contractsorm.Query { conditions := r.conditions conditions.join = append(conditions.join, Join{ query: query, @@ -417,7 +418,7 @@ func (r *Query) Join(query string, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *Query) Limit(limit int) ormcontract.Query { +func (r *Query) Limit(limit int) contractsorm.Query { conditions := r.conditions conditions.limit = &limit @@ -489,42 +490,42 @@ func (r *Query) LoadMissing(model any, relation string, args ...any) error { return r.Load(model, relation, args...) } -func (r *Query) LockForUpdate() ormcontract.Query { +func (r *Query) LockForUpdate() contractsorm.Query { conditions := r.conditions conditions.lockForUpdate = true return r.setConditions(conditions) } -func (r *Query) Model(value any) ormcontract.Query { +func (r *Query) Model(value any) contractsorm.Query { conditions := r.conditions conditions.model = value return r.setConditions(conditions) } -func (r *Query) Offset(offset int) ormcontract.Query { +func (r *Query) Offset(offset int) contractsorm.Query { conditions := r.conditions conditions.offset = &offset return r.setConditions(conditions) } -func (r *Query) Omit(columns ...string) ormcontract.Query { +func (r *Query) Omit(columns ...string) contractsorm.Query { conditions := r.conditions conditions.omit = columns return r.setConditions(conditions) } -func (r *Query) Order(value any) ormcontract.Query { +func (r *Query) Order(value any) contractsorm.Query { conditions := r.conditions conditions.order = append(r.conditions.order, value) return r.setConditions(conditions) } -func (r *Query) OrderBy(column string, direction ...string) ormcontract.Query { +func (r *Query) OrderBy(column string, direction ...string) contractsorm.Query { var orderDirection string if len(direction) > 0 { orderDirection = direction[0] @@ -534,26 +535,26 @@ func (r *Query) OrderBy(column string, direction ...string) ormcontract.Query { return r.Order(fmt.Sprintf("%s %s", column, orderDirection)) } -func (r *Query) OrderByDesc(column string) ormcontract.Query { +func (r *Query) OrderByDesc(column string) contractsorm.Query { return r.Order(fmt.Sprintf("%s DESC", column)) } -func (r *Query) InRandomOrder() ormcontract.Query { +func (r *Query) InRandomOrder() contractsorm.Query { order := "" switch r.Driver() { - case ormcontract.DriverMysql: + case contractsdatabase.DriverMysql: order = "RAND()" - case ormcontract.DriverSqlserver: + case contractsdatabase.DriverSqlserver: order = "NEWID()" - case ormcontract.DriverPostgres: + case contractsdatabase.DriverPostgres: order = "RANDOM()" - case ormcontract.DriverSqlite: + case contractsdatabase.DriverSqlite: order = "RANDOM()" } return r.Order(order) } -func (r *Query) OrWhere(query any, args ...any) ormcontract.Query { +func (r *Query) OrWhere(query any, args ...any) contractsorm.Query { conditions := r.conditions conditions.where = append(r.conditions.where, Where{ query: query, @@ -594,7 +595,7 @@ func (r *Query) Pluck(column string, dest any) error { return query.instance.Pluck(column, dest).Error } -func (r *Query) Raw(sql string, values ...any) ormcontract.Query { +func (r *Query) Raw(sql string, values ...any) contractsorm.Query { return r.new(r.instance.Raw(sql, values...)) } @@ -675,14 +676,14 @@ func (r *Query) Scan(dest any) error { return query.instance.Scan(dest).Error } -func (r *Query) Scopes(funcs ...func(ormcontract.Query) ormcontract.Query) ormcontract.Query { +func (r *Query) Scopes(funcs ...func(contractsorm.Query) contractsorm.Query) contractsorm.Query { conditions := r.conditions conditions.scopes = append(r.conditions.scopes, funcs...) return r.setConditions(conditions) } -func (r *Query) Select(query any, args ...any) ormcontract.Query { +func (r *Query) Select(query any, args ...any) contractsorm.Query { conditions := r.conditions conditions.selectColumns = &Select{ query: query, @@ -697,7 +698,7 @@ func (r *Query) SetContext(ctx context.Context) { r.instance.Statement.Context = ctx } -func (r *Query) SharedLock() ormcontract.Query { +func (r *Query) SharedLock() contractsorm.Query { conditions := r.conditions conditions.sharedLock = true @@ -710,7 +711,7 @@ func (r *Query) Sum(column string, dest any) error { return query.instance.Select("SUM(" + column + ")").Row().Scan(dest) } -func (r *Query) Table(name string, args ...any) ormcontract.Query { +func (r *Query) Table(name string, args ...any) contractsorm.Query { conditions := r.conditions conditions.table = &Table{ name: name, @@ -720,15 +721,15 @@ func (r *Query) Table(name string, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *Query) ToSql() ormcontract.ToSql { +func (r *Query) ToSql() contractsorm.ToSql { return NewToSql(r.setConditions(r.conditions), false) } -func (r *Query) ToRawSql() ormcontract.ToSql { +func (r *Query) ToRawSql() contractsorm.ToSql { return NewToSql(r.setConditions(r.conditions), true) } -func (r *Query) Update(column any, value ...any) (*ormcontract.Result, error) { +func (r *Query) Update(column any, value ...any) (*contractsorm.Result, error) { query := r.buildConditions() if _, ok := column.(string); !ok && len(value) > 0 { @@ -791,7 +792,7 @@ func (r *Query) UpdateOrCreate(dest any, attributes any, values any) error { return query.Create(dest) } -func (r *Query) Where(query any, args ...any) ormcontract.Query { +func (r *Query) Where(query any, args ...any) contractsorm.Query { conditions := r.conditions conditions.where = append(r.conditions.where, Where{ query: query, @@ -801,51 +802,51 @@ func (r *Query) Where(query any, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *Query) WhereIn(column string, values []any) ormcontract.Query { +func (r *Query) WhereIn(column string, values []any) contractsorm.Query { return r.Where(fmt.Sprintf("%s IN ?", column), values) } -func (r *Query) OrWhereIn(column string, values []any) ormcontract.Query { +func (r *Query) OrWhereIn(column string, values []any) contractsorm.Query { return r.OrWhere(fmt.Sprintf("%s IN ?", column), values) } -func (r *Query) WhereNotIn(column string, values []any) ormcontract.Query { +func (r *Query) WhereNotIn(column string, values []any) contractsorm.Query { return r.Where(fmt.Sprintf("%s NOT IN ?", column), values) } -func (r *Query) OrWhereNotIn(column string, values []any) ormcontract.Query { +func (r *Query) OrWhereNotIn(column string, values []any) contractsorm.Query { return r.OrWhere(fmt.Sprintf("%s NOT IN ?", column), values) } -func (r *Query) WhereBetween(column string, x, y any) ormcontract.Query { +func (r *Query) WhereBetween(column string, x, y any) contractsorm.Query { return r.Where(fmt.Sprintf("%s BETWEEN %v AND %v", column, x, y)) } -func (r *Query) WhereNotBetween(column string, x, y any) ormcontract.Query { +func (r *Query) WhereNotBetween(column string, x, y any) contractsorm.Query { return r.Where(fmt.Sprintf("%s NOT BETWEEN %v AND %v", column, x, y)) } -func (r *Query) OrWhereBetween(column string, x, y any) ormcontract.Query { +func (r *Query) OrWhereBetween(column string, x, y any) contractsorm.Query { return r.OrWhere(fmt.Sprintf("%s BETWEEN %v AND %v", column, x, y)) } -func (r *Query) OrWhereNotBetween(column string, x, y any) ormcontract.Query { +func (r *Query) OrWhereNotBetween(column string, x, y any) contractsorm.Query { return r.OrWhere(fmt.Sprintf("%s NOT BETWEEN %v AND %v", column, x, y)) } -func (r *Query) OrWhereNull(column string) ormcontract.Query { +func (r *Query) OrWhereNull(column string) contractsorm.Query { return r.OrWhere(fmt.Sprintf("%s IS NULL", column)) } -func (r *Query) WhereNull(column string) ormcontract.Query { +func (r *Query) WhereNull(column string) contractsorm.Query { return r.Where(fmt.Sprintf("%s IS NULL", column)) } -func (r *Query) WhereNotNull(column string) ormcontract.Query { +func (r *Query) WhereNotNull(column string) contractsorm.Query { return r.Where(fmt.Sprintf("%s IS NOT NULL", column)) } -func (r *Query) With(query string, args ...any) ormcontract.Query { +func (r *Query) With(query string, args ...any) contractsorm.Query { conditions := r.conditions conditions.with = append(r.conditions.with, With{ query: query, @@ -855,14 +856,14 @@ func (r *Query) With(query string, args ...any) ormcontract.Query { return r.setConditions(conditions) } -func (r *Query) WithoutEvents() ormcontract.Query { +func (r *Query) WithoutEvents() contractsorm.Query { conditions := r.conditions conditions.withoutEvents = true return r.setConditions(conditions) } -func (r *Query) WithTrashed() ormcontract.Query { +func (r *Query) WithTrashed() contractsorm.Query { conditions := r.conditions conditions.withTrashed = true @@ -1113,7 +1114,7 @@ func (r *Query) buildWith(db *gormio.DB) *gormio.DB { for _, item := range r.conditions.with { isSet := false if len(item.args) == 1 { - if arg, ok := item.args[0].(func(ormcontract.Query) ormcontract.Query); ok { + if arg, ok := item.args[0].(func(contractsorm.Query) contractsorm.Query); ok { newArgs := []any{ func(tx *gormio.DB) *gormio.DB { queryImpl := NewQuery(r.ctx, r.config, r.connection, tx, nil) @@ -1178,30 +1179,30 @@ func (r *Query) create(dest any) error { } func (r *Query) created(dest any) error { - return r.event(ormcontract.EventCreated, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventCreated, r.instance.Statement.Model, dest) } func (r *Query) creating(dest any) error { - return r.event(ormcontract.EventCreating, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventCreating, r.instance.Statement.Model, dest) } func (r *Query) deleting(dest any) error { - return r.event(ormcontract.EventDeleting, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventDeleting, r.instance.Statement.Model, dest) } func (r *Query) deleted(dest any) error { - return r.event(ormcontract.EventDeleted, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventDeleted, r.instance.Statement.Model, dest) } func (r *Query) forceDeleting(dest any) error { - return r.event(ormcontract.EventForceDeleting, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventForceDeleting, r.instance.Statement.Model, dest) } func (r *Query) forceDeleted(dest any) error { - return r.event(ormcontract.EventForceDeleted, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventForceDeleted, r.instance.Statement.Model, dest) } -func (r *Query) event(event ormcontract.EventType, model, dest any) error { +func (r *Query) event(event contractsorm.EventType, model, dest any) error { if r.conditions.withoutEvents { return nil } @@ -1209,7 +1210,7 @@ func (r *Query) event(event ormcontract.EventType, model, dest any) error { instance := NewEvent(r, model, dest) if dest != nil { - if dispatchesEvents, exist := dest.(ormcontract.DispatchesEvents); exist { + if dispatchesEvents, exist := dest.(contractsorm.DispatchesEvents); exist { if dispatchesEvent, exists := dispatchesEvents.DispatchesEvents()[event]; exists { return dispatchesEvent(instance) } @@ -1218,7 +1219,7 @@ func (r *Query) event(event ormcontract.EventType, model, dest any) error { } if model != nil { - if dispatchesEvents, exist := model.(ormcontract.DispatchesEvents); exist { + if dispatchesEvents, exist := model.(contractsorm.DispatchesEvents); exist { if dispatchesEvent, exists := dispatchesEvents.DispatchesEvents()[event]; exists { return dispatchesEvent(instance) } @@ -1334,7 +1335,7 @@ func (r *Query) refreshConnection(model any) (*Query, error) { } func (r *Query) retrieved(dest any) error { - return r.event(ormcontract.EventRetrieved, nil, dest) + return r.event(contractsorm.EventRetrieved, nil, dest) } func (r *Query) save(value any) error { @@ -1342,11 +1343,11 @@ func (r *Query) save(value any) error { } func (r *Query) saved(dest any) error { - return r.event(ormcontract.EventSaved, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventSaved, r.instance.Statement.Model, dest) } func (r *Query) saving(dest any) error { - return r.event(ormcontract.EventSaving, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventSaving, r.instance.Statement.Model, dest) } func (r *Query) selectCreate(value any) error { @@ -1405,14 +1406,14 @@ func (r *Query) setConditions(conditions Conditions) *Query { } func (r *Query) updating(dest any) error { - return r.event(ormcontract.EventUpdating, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventUpdating, r.instance.Statement.Model, dest) } func (r *Query) updated(dest any) error { - return r.event(ormcontract.EventUpdated, r.instance.Statement.Model, dest) + return r.event(contractsorm.EventUpdated, r.instance.Statement.Model, dest) } -func (r *Query) updates(values any) (*ormcontract.Result, error) { +func (r *Query) updates(values any) (*contractsorm.Result, error) { if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { return nil, errors.New("cannot set Select and Omits at the same time") } @@ -1421,7 +1422,7 @@ func (r *Query) updates(values any) (*ormcontract.Result, error) { for _, val := range r.instance.Statement.Selects { if val == orm.Associations { result := r.instance.Session(&gormio.Session{FullSaveAssociations: true}).Updates(values) - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: result.RowsAffected, }, result.Error } @@ -1429,7 +1430,7 @@ func (r *Query) updates(values any) (*ormcontract.Result, error) { result := r.instance.Updates(values) - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: result.RowsAffected, }, result.Error } @@ -1439,20 +1440,20 @@ func (r *Query) updates(values any) (*ormcontract.Result, error) { if val == orm.Associations { result := r.instance.Omit(orm.Associations).Updates(values) - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: result.RowsAffected, }, result.Error } } result := r.instance.Updates(values) - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: result.RowsAffected, }, result.Error } result := r.instance.Omit(orm.Associations).Updates(values) - return &ormcontract.Result{ + return &contractsorm.Result{ RowsAffected: result.RowsAffected, }, result.Error } @@ -1506,7 +1507,7 @@ func getModelConnection(model any) (string, error) { } newModel := reflect.New(modelType) - connectionModel, ok := newModel.Interface().(ormcontract.ConnectionModel) + connectionModel, ok := newModel.Interface().(contractsorm.ConnectionModel) if !ok { return "", nil } @@ -1514,7 +1515,7 @@ func getModelConnection(model any) (string, error) { return connectionModel.Connection(), nil } -func getObserver(dest any) ormcontract.Observer { +func getObserver(dest any) contractsorm.Observer { destType := reflect.TypeOf(dest) if destType.Kind() == reflect.Pointer { destType = destType.Elem() @@ -1533,29 +1534,29 @@ func getObserver(dest any) ormcontract.Observer { return nil } -func getObserverEvent(event ormcontract.EventType, observer ormcontract.Observer) func(ormcontract.Event) error { +func getObserverEvent(event contractsorm.EventType, observer contractsorm.Observer) func(contractsorm.Event) error { switch event { - case ormcontract.EventRetrieved: + case contractsorm.EventRetrieved: return observer.Retrieved - case ormcontract.EventCreating: + case contractsorm.EventCreating: return observer.Creating - case ormcontract.EventCreated: + case contractsorm.EventCreated: return observer.Created - case ormcontract.EventUpdating: + case contractsorm.EventUpdating: return observer.Updating - case ormcontract.EventUpdated: + case contractsorm.EventUpdated: return observer.Updated - case ormcontract.EventSaving: + case contractsorm.EventSaving: return observer.Saving - case ormcontract.EventSaved: + case contractsorm.EventSaved: return observer.Saved - case ormcontract.EventDeleting: + case contractsorm.EventDeleting: return observer.Deleting - case ormcontract.EventDeleted: + case contractsorm.EventDeleted: return observer.Deleted - case ormcontract.EventForceDeleting: + case contractsorm.EventForceDeleting: return observer.ForceDeleting - case ormcontract.EventForceDeleted: + case contractsorm.EventForceDeleted: return observer.ForceDeleted } diff --git a/database/gorm/query_test.go b/database/gorm/query_test.go index 7fb5b9e85..048c468b3 100644 --- a/database/gorm/query_test.go +++ b/database/gorm/query_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/suite" _ "gorm.io/driver/postgres" + "github.com/goravel/framework/contracts/database" contractsorm "github.com/goravel/framework/contracts/database/orm" databasedb "github.com/goravel/framework/database/db" "github.com/goravel/framework/database/orm" @@ -23,7 +24,7 @@ import ( type QueryTestSuite struct { suite.Suite - queries map[contractsorm.Driver]*TestQuery + queries map[database.Driver]*TestQuery additionalQuery *TestQuery } @@ -608,7 +609,7 @@ func (s *QueryTestSuite) TestDBRaw() { s.Nil(query.Query().Create(&user)) s.True(user.ID > 0) switch driver { - case contractsorm.DriverSqlserver, contractsorm.DriverMysql: + case database.DriverSqlserver, database.DriverMysql: res, err := query.Query().Model(&user).Update("Name", databasedb.Raw("concat(name, ?)", driver.String())) s.Nil(err) s.Equal(int64(1), res.RowsAffected) @@ -2043,7 +2044,7 @@ func (s *QueryTestSuite) TestJoin() { func (s *QueryTestSuite) TestLockForUpdate() { for driver, query := range s.queries { - if driver != contractsorm.DriverSqlite { + if driver != database.DriverSqlite { s.Run(driver.String(), func() { user := User{Name: "lock_for_update_user"} s.Nil(query.Query().Create(&user)) @@ -2625,7 +2626,7 @@ func (s *QueryTestSuite) TestRefreshConnection() { return people }(), setup: func() { - mockCommonConnection(s.queries[contractsorm.DriverPostgres].MockConfig(), s.additionalQuery, "dummy") + mockCommonConnection(s.queries[database.DriverPostgres].MockConfig(), s.additionalQuery, "dummy") }, expectConnection: "dummy", }, @@ -2636,7 +2637,7 @@ func (s *QueryTestSuite) TestRefreshConnection() { return product }(), setup: func() { - mockCommonConnection(s.queries[contractsorm.DriverPostgres].MockConfig(), s.queries[contractsorm.DriverSqlite], "sqlite") + mockCommonConnection(s.queries[database.DriverPostgres].MockConfig(), s.queries[database.DriverSqlite], "sqlite") }, expectConnection: "sqlite", }, @@ -2645,7 +2646,7 @@ func (s *QueryTestSuite) TestRefreshConnection() { for _, test := range tests { s.Run(test.name, func() { test.setup() - testQuery := s.queries[contractsorm.DriverPostgres] + testQuery := s.queries[database.DriverPostgres] query, err := testQuery.Query().(*Query).refreshConnection(test.model) if test.expectErr != "" { s.EqualError(err, test.expectErr) @@ -2774,7 +2775,7 @@ func (s *QueryTestSuite) TestSelect() { func (s *QueryTestSuite) TestSharedLock() { for driver, query := range s.queries { - if driver != contractsorm.DriverSqlite { + if driver != database.DriverSqlite { s.Run(driver.String(), func() { user := User{Name: "shared_lock_user"} s.Nil(query.Query().Create(&user)) @@ -2856,9 +2857,9 @@ func (s *QueryTestSuite) TestToSql() { for driver, query := range s.queries { s.Run(driver.String(), func() { switch driver { - case contractsorm.DriverPostgres: + case database.DriverPostgres: s.Equal("SELECT * FROM \"users\" WHERE \"id\" = $1 AND \"users\".\"deleted_at\" IS NULL", query.Query().Where("id", 1).ToSql().Find(User{})) - case contractsorm.DriverSqlserver: + case database.DriverSqlserver: s.Equal("SELECT * FROM \"users\" WHERE \"id\" = @p1 AND \"users\".\"deleted_at\" IS NULL", query.Query().Where("id", 1).ToSql().Find(User{})) default: s.Equal("SELECT * FROM `users` WHERE `id` = ? AND `users`.`deleted_at` IS NULL", query.Query().Where("id", 1).ToSql().Find(User{})) @@ -2871,9 +2872,9 @@ func (s *QueryTestSuite) TestToRawSql() { for driver, query := range s.queries { s.Run(driver.String(), func() { switch driver { - case contractsorm.DriverPostgres: + case database.DriverPostgres: s.Equal("SELECT * FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", query.Query().Where("id", 1).ToRawSql().Find(User{})) - case contractsorm.DriverSqlserver: + case database.DriverSqlserver: s.Equal("SELECT * FROM \"users\" WHERE \"id\" = $1$ AND \"users\".\"deleted_at\" IS NULL", query.Query().Where("id", 1).ToRawSql().Find(User{})) default: s.Equal("SELECT * FROM `users` WHERE `id` = 1 AND `users`.`deleted_at` IS NULL", query.Query().Where("id", 1).ToRawSql().Find(User{})) @@ -3655,7 +3656,7 @@ func paginator(page string, limit string) func(methods contractsorm.Query) contr } func mockCommonConnection(mockConfig *mocksconfig.Config, testQuery *TestQuery, connection string) { - mockDriver := GetMockDriver(testQuery.Docker(), mockConfig, connection) + mockDriver := getMockDriver(testQuery.Docker(), mockConfig, connection) mockDriver.Common() } diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index 7eb816908..010506609 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -5,7 +5,7 @@ import ( "fmt" "slices" - "github.com/goravel/framework/contracts/database" + contractsdatabase "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/contracts/testing" mocksconfig "github.com/goravel/framework/mocks/config" @@ -79,11 +79,11 @@ func NewTestQueries() *TestQueries { } } -func (r *TestQueries) Queries() map[orm.Driver]*TestQuery { +func (r *TestQueries) Queries() map[contractsdatabase.Driver]*TestQuery { return r.queries(false) } -func (r *TestQueries) QueriesOfReadWrite() map[orm.Driver]map[string]orm.Query { +func (r *TestQueries) QueriesOfReadWrite() map[contractsdatabase.Driver]map[string]orm.Query { readPostgresQuery := NewTestQuery(r.postgresDockers[0]) readPostgresQuery.CreateTable(TestTableUsers) @@ -112,13 +112,13 @@ func (r *TestQueries) QueriesOfReadWrite() map[orm.Driver]map[string]orm.Query { } if TestModel == TestModelMinimum { - return map[orm.Driver]map[string]orm.Query{ - orm.DriverPostgres: { + return map[contractsdatabase.Driver]map[string]orm.Query{ + contractsdatabase.DriverPostgres: { "mix": postgresQuery, "read": readPostgresQuery.Query(), "write": writePostgresQuery.Query(), }, - orm.DriverSqlite: { + contractsdatabase.DriverSqlite: { "mix": sqliteQuery, "read": readSqliteQuery.Query(), "write": writeSqliteQuery.Query(), @@ -154,23 +154,23 @@ func (r *TestQueries) QueriesOfReadWrite() map[orm.Driver]map[string]orm.Query { panic(err) } - return map[orm.Driver]map[string]orm.Query{ - orm.DriverMysql: { + return map[contractsdatabase.Driver]map[string]orm.Query{ + contractsdatabase.DriverMysql: { "mix": mysqlQuery, "read": readMysqlQuery.Query(), "write": writeMysqlQuery.Query(), }, - orm.DriverPostgres: { + contractsdatabase.DriverPostgres: { "mix": postgresQuery, "read": readPostgresQuery.Query(), "write": writePostgresQuery.Query(), }, - orm.DriverSqlite: { + contractsdatabase.DriverSqlite: { "mix": sqliteQuery, "read": readSqliteQuery.Query(), "write": writeSqliteQuery.Query(), }, - orm.DriverSqlserver: { + contractsdatabase.DriverSqlserver: { "mix": sqlserverQuery, "read": readSqlserverQuery.Query(), "write": writeSqlserverQuery.Query(), @@ -178,7 +178,7 @@ func (r *TestQueries) QueriesOfReadWrite() map[orm.Driver]map[string]orm.Query { } } -func (r *TestQueries) QueriesWithPrefixAndSingular() map[orm.Driver]*TestQuery { +func (r *TestQueries) QueriesWithPrefixAndSingular() map[contractsdatabase.Driver]*TestQuery { return r.queries(true) } @@ -189,17 +189,17 @@ func (r *TestQueries) QueryOfAdditional() *TestQuery { return postgresQuery } -func (r *TestQueries) queries(withPrefixAndSingular bool) map[orm.Driver]*TestQuery { - driverToTestQuery := make(map[orm.Driver]*TestQuery) +func (r *TestQueries) queries(withPrefixAndSingular bool) map[contractsdatabase.Driver]*TestQuery { + driverToTestQuery := make(map[contractsdatabase.Driver]*TestQuery) - driverToDocker := map[orm.Driver]testing.DatabaseDriver{ - orm.DriverPostgres: r.postgresDockers[0], - orm.DriverSqlite: r.sqliteDockers[0], + driverToDocker := map[contractsdatabase.Driver]testing.DatabaseDriver{ + contractsdatabase.DriverPostgres: r.postgresDockers[0], + contractsdatabase.DriverSqlite: r.sqliteDockers[0], } if TestModel != TestModelMinimum { - driverToDocker[orm.DriverMysql] = r.mysqlDockers[0] - driverToDocker[orm.DriverSqlserver] = r.sqlserverDockers[0] + driverToDocker[contractsdatabase.DriverMysql] = r.mysqlDockers[0] + driverToDocker[contractsdatabase.DriverSqlserver] = r.sqlserverDockers[0] } for driver, docker := range driverToDocker { @@ -220,7 +220,7 @@ type TestQuery struct { func NewTestQuery(docker testing.DatabaseDriver, withPrefixAndSingular ...bool) *TestQuery { mockConfig := &mocksconfig.Config{} - mockDriver := GetMockDriver(docker, mockConfig, docker.Driver().String()) + mockDriver := getMockDriver(docker, mockConfig, docker.Driver().String()) testQuery := &TestQuery{ docker: docker, @@ -273,23 +273,23 @@ func (r *TestQuery) Query() orm.Query { func (r *TestQuery) QueryOfReadWrite(config TestReadWriteConfig) (orm.Query, error) { mockConfig := &mocksconfig.Config{} - mockDriver := GetMockDriver(r.Docker(), mockConfig, r.Docker().Driver().String()) + mockDriver := getMockDriver(r.Docker(), mockConfig, r.Docker().Driver().String()) mockDriver.ReadWrite(config) return BuildQuery(testContext, mockConfig, r.docker.Driver().String()) } -func GetMockDriver(docker testing.DatabaseDriver, mockConfig *mocksconfig.Config, connection string) testMockDriver { +func getMockDriver(docker testing.DatabaseDriver, mockConfig *mocksconfig.Config, connection string) testMockDriver { config := docker.Config() switch docker.Driver() { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return NewMockMysql(mockConfig, connection, config.Database, config.Username, config.Password, config.Port) - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return NewMockPostgres(mockConfig, connection, config.Database, config.Username, config.Password, config.Port) - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return NewMockSqlite(mockConfig, connection, config.Database) - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return NewMockSqlserver(mockConfig, connection, config.Database, config.Username, config.Password, config.Port) default: panic("unsupported driver") @@ -297,7 +297,7 @@ func GetMockDriver(docker testing.DatabaseDriver, mockConfig *mocksconfig.Config } type MockMysql struct { - driver orm.Driver + driver contractsdatabase.Driver mockConfig *mocksconfig.Config connection string @@ -309,7 +309,7 @@ type MockMysql struct { func NewMockMysql(mockConfig *mocksconfig.Config, connection, database, username, password string, port int) *MockMysql { return &MockMysql{ - driver: orm.DriverMysql, + driver: contractsdatabase.DriverMysql, mockConfig: mockConfig, connection: connection, database: database, @@ -320,55 +320,55 @@ func NewMockMysql(mockConfig *mocksconfig.Config, connection, database, username } func (r *MockMysql) Common() { - r.mockConfig.On("GetString", "database.default").Return("mysql") - r.mockConfig.On("GetString", "database.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", "contractsdatabase.default").Return("mysql") + r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockMysql) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.ReadPort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.WritePort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockMysql) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockMysql) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.charset", r.connection)).Return("utf8mb4") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.loc", r.connection)).Return("Local") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.charset", r.connection)).Return("utf8mb4") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.loc", r.connection)).Return("Local") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) mockPool(r.mockConfig) } func (r *MockMysql) single() { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", r.connection)).Return("127.0.0.1") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", r.connection)).Return(r.user) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", r.connection)).Return(r.password) - r.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", r.connection)).Return(r.port) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.host", r.connection)).Return("127.0.0.1") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.username", r.connection)).Return(r.user) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.password", r.connection)).Return(r.password) + r.mockConfig.On("GetInt", fmt.Sprintf("contractsdatabase.connections.%s.port", r.connection)).Return(r.port) } type MockPostgres struct { - driver orm.Driver + driver contractsdatabase.Driver mockConfig *mocksconfig.Config connection string @@ -380,7 +380,7 @@ type MockPostgres struct { func NewMockPostgres(mockConfig *mocksconfig.Config, connection, database, username, password string, port int) *MockPostgres { return &MockPostgres{ - driver: orm.DriverPostgres, + driver: contractsdatabase.DriverPostgres, mockConfig: mockConfig, connection: connection, database: database, @@ -391,54 +391,54 @@ func NewMockPostgres(mockConfig *mocksconfig.Config, connection, database, usern } func (r *MockPostgres) Common() { - r.mockConfig.On("GetString", "database.default").Return("postgres") - r.mockConfig.On("GetString", "database.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", "contractsdatabase.default").Return("postgres") + r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockPostgres) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.ReadPort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.WritePort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockPostgres) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockPostgres) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.sslmode", r.connection)).Return("disable") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.timezone", r.connection)).Return("UTC") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.sslmode", r.connection)).Return("disable") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.timezone", r.connection)).Return("UTC") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) mockPool(r.mockConfig) } func (r *MockPostgres) single() { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", r.connection)).Return("127.0.0.1") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", r.connection)).Return(r.user) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", r.connection)).Return(r.password) - r.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", r.connection)).Return(r.port) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.host", r.connection)).Return("127.0.0.1") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.username", r.connection)).Return(r.user) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.password", r.connection)).Return(r.password) + r.mockConfig.On("GetInt", fmt.Sprintf("contractsdatabase.connections.%s.port", r.connection)).Return(r.port) } type MockSqlite struct { - driver orm.Driver + driver contractsdatabase.Driver mockConfig *mocksconfig.Config connection string @@ -447,7 +447,7 @@ type MockSqlite struct { func NewMockSqlite(mockConfig *mocksconfig.Config, connection, database string) *MockSqlite { return &MockSqlite{ - driver: orm.DriverSqlite, + driver: contractsdatabase.DriverSqlite, mockConfig: mockConfig, connection: connection, database: database, @@ -455,47 +455,47 @@ func NewMockSqlite(mockConfig *mocksconfig.Config, connection, database string) } func (r *MockSqlite) Common() { - r.mockConfig.On("GetString", "database.default").Return("sqlite") - r.mockConfig.On("GetString", "database.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", "contractsdatabase.default").Return("sqlite") + r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockSqlite) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Database: config.ReadDatabase}, }) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Database: r.database}, }) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockSqlite) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockSqlite) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) mockPool(r.mockConfig) } func (r *MockSqlite) single() { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) } type MockSqlserver struct { - driver orm.Driver + driver contractsdatabase.Driver mockConfig *mocksconfig.Config connection string @@ -507,7 +507,7 @@ type MockSqlserver struct { func NewMockSqlserver(mockConfig *mocksconfig.Config, connection, database, username, password string, port int) *MockSqlserver { return &MockSqlserver{ - driver: orm.DriverSqlserver, + driver: contractsdatabase.DriverSqlserver, mockConfig: mockConfig, connection: connection, database: database, @@ -518,55 +518,55 @@ func NewMockSqlserver(mockConfig *mocksconfig.Config, connection, database, user } func (r *MockSqlserver) Common() { - r.mockConfig.On("GetString", "database.default").Return("sqlserver") - r.mockConfig.On("GetString", "database.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", "contractsdatabase.default").Return("sqlserver") + r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockSqlserver) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.ReadPort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]database.Config{ + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.WritePort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockSqlserver) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockSqlserver) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.charset", r.connection)).Return("utf8mb4") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.charset", r.connection)).Return("utf8mb4") mockPool(r.mockConfig) } func (r *MockSqlserver) single() { - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", r.connection)).Return("127.0.0.1") - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", r.connection)).Return(r.user) - r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", r.connection)).Return(r.password) - r.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", r.connection)).Return(r.port) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.host", r.connection)).Return("127.0.0.1") + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.username", r.connection)).Return(r.user) + r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.password", r.connection)).Return(r.password) + r.mockConfig.On("GetInt", fmt.Sprintf("contractsdatabase.connections.%s.port", r.connection)).Return(r.port) } type testTables struct { - driver orm.Driver + driver contractsdatabase.Driver } -func newTestTables(driver orm.Driver) *testTables { +func newTestTables(driver contractsdatabase.Driver) *testTables { return &testTables{driver: driver} } @@ -589,7 +589,7 @@ func (r *testTables) All() map[TestTable]func() string { func (r *testTables) peoples() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE peoples ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -602,7 +602,7 @@ CREATE TABLE peoples ( KEY idx_users_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE peoples ( id SERIAL PRIMARY KEY NOT NULL, @@ -612,7 +612,7 @@ CREATE TABLE peoples ( deleted_at timestamp DEFAULT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE peoples ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -622,7 +622,7 @@ CREATE TABLE peoples ( deleted_at datetime DEFAULT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE peoples ( id bigint NOT NULL IDENTITY(1,1), @@ -640,7 +640,7 @@ CREATE TABLE peoples ( func (r *testTables) reviews() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE reviews ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -653,7 +653,7 @@ CREATE TABLE reviews ( KEY idx_users_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE reviews ( id SERIAL PRIMARY KEY NOT NULL, @@ -663,7 +663,7 @@ CREATE TABLE reviews ( deleted_at timestamp DEFAULT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE reviews ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -673,7 +673,7 @@ CREATE TABLE reviews ( deleted_at datetime DEFAULT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE reviews ( id bigint NOT NULL IDENTITY(1,1), @@ -691,7 +691,7 @@ CREATE TABLE reviews ( func (r *testTables) products() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE products ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -704,7 +704,7 @@ CREATE TABLE products ( KEY idx_users_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE products ( id SERIAL PRIMARY KEY NOT NULL, @@ -714,7 +714,7 @@ CREATE TABLE products ( deleted_at timestamp DEFAULT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE products ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -724,7 +724,7 @@ CREATE TABLE products ( deleted_at datetime DEFAULT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE products ( id bigint NOT NULL IDENTITY(1,1), @@ -742,7 +742,7 @@ CREATE TABLE products ( func (r *testTables) users() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE users ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -757,7 +757,7 @@ CREATE TABLE users ( KEY idx_users_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE users ( id SERIAL PRIMARY KEY NOT NULL, @@ -769,7 +769,7 @@ CREATE TABLE users ( deleted_at timestamp DEFAULT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE users ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -781,7 +781,7 @@ CREATE TABLE users ( deleted_at datetime DEFAULT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE users ( id bigint NOT NULL IDENTITY(1,1), @@ -801,7 +801,7 @@ CREATE TABLE users ( func (r *testTables) goravelUser() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE goravel_user ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -816,7 +816,7 @@ CREATE TABLE goravel_user ( KEY idx_users_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE goravel_user ( id SERIAL PRIMARY KEY NOT NULL, @@ -828,7 +828,7 @@ CREATE TABLE goravel_user ( deleted_at timestamp DEFAULT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE goravel_user ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -840,7 +840,7 @@ CREATE TABLE goravel_user ( deleted_at datetime DEFAULT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE goravel_user ( id bigint NOT NULL IDENTITY(1,1), @@ -860,7 +860,7 @@ CREATE TABLE goravel_user ( func (r *testTables) addresses() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE addresses ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -874,7 +874,7 @@ CREATE TABLE addresses ( KEY idx_addresses_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE addresses ( id SERIAL PRIMARY KEY NOT NULL, @@ -885,7 +885,7 @@ CREATE TABLE addresses ( updated_at timestamp NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE addresses ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -896,7 +896,7 @@ CREATE TABLE addresses ( updated_at datetime NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE addresses ( id bigint NOT NULL IDENTITY(1,1), @@ -915,7 +915,7 @@ CREATE TABLE addresses ( func (r *testTables) books() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE books ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -928,7 +928,7 @@ CREATE TABLE books ( KEY idx_books_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE books ( id SERIAL PRIMARY KEY NOT NULL, @@ -938,7 +938,7 @@ CREATE TABLE books ( updated_at timestamp NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE books ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -948,7 +948,7 @@ CREATE TABLE books ( updated_at datetime NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE books ( id bigint NOT NULL IDENTITY(1,1), @@ -966,7 +966,7 @@ CREATE TABLE books ( func (r *testTables) authors() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE authors ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -979,7 +979,7 @@ CREATE TABLE authors ( KEY idx_books_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE authors ( id SERIAL PRIMARY KEY NOT NULL, @@ -989,7 +989,7 @@ CREATE TABLE authors ( updated_at timestamp NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE authors ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -999,7 +999,7 @@ CREATE TABLE authors ( updated_at datetime NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE authors ( id bigint NOT NULL IDENTITY(1,1), @@ -1017,7 +1017,7 @@ CREATE TABLE authors ( func (r *testTables) roles() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE roles ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -1029,7 +1029,7 @@ CREATE TABLE roles ( KEY idx_roles_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE roles ( id SERIAL PRIMARY KEY NOT NULL, @@ -1038,7 +1038,7 @@ CREATE TABLE roles ( updated_at timestamp NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE roles ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -1047,7 +1047,7 @@ CREATE TABLE roles ( updated_at datetime NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE roles ( id bigint NOT NULL IDENTITY(1,1), @@ -1064,7 +1064,7 @@ CREATE TABLE roles ( func (r *testTables) houses() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE houses ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -1078,7 +1078,7 @@ CREATE TABLE houses ( KEY idx_houses_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE houses ( id SERIAL PRIMARY KEY NOT NULL, @@ -1089,7 +1089,7 @@ CREATE TABLE houses ( updated_at timestamp NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE houses ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -1100,7 +1100,7 @@ CREATE TABLE houses ( updated_at datetime NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE houses ( id bigint NOT NULL IDENTITY(1,1), @@ -1119,7 +1119,7 @@ CREATE TABLE houses ( func (r *testTables) phones() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE phones ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -1133,7 +1133,7 @@ CREATE TABLE phones ( KEY idx_phones_updated_at (updated_at) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE phones ( id SERIAL PRIMARY KEY NOT NULL, @@ -1144,7 +1144,7 @@ CREATE TABLE phones ( updated_at timestamp NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE phones ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -1155,7 +1155,7 @@ CREATE TABLE phones ( updated_at datetime NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE phones ( id bigint NOT NULL IDENTITY(1,1), @@ -1174,7 +1174,7 @@ CREATE TABLE phones ( func (r *testTables) roleUser() string { switch r.driver { - case orm.DriverMysql: + case contractsdatabase.DriverMysql: return ` CREATE TABLE role_user ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, @@ -1183,7 +1183,7 @@ CREATE TABLE role_user ( PRIMARY KEY (id) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4; ` - case orm.DriverPostgres: + case contractsdatabase.DriverPostgres: return ` CREATE TABLE role_user ( id SERIAL PRIMARY KEY NOT NULL, @@ -1191,7 +1191,7 @@ CREATE TABLE role_user ( user_id int NOT NULL ); ` - case orm.DriverSqlite: + case contractsdatabase.DriverSqlite: return ` CREATE TABLE role_user ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, @@ -1199,7 +1199,7 @@ CREATE TABLE role_user ( user_id int NOT NULL ); ` - case orm.DriverSqlserver: + case contractsdatabase.DriverSqlserver: return ` CREATE TABLE role_user ( id bigint NOT NULL IDENTITY(1,1), @@ -1214,8 +1214,8 @@ CREATE TABLE role_user ( } func mockPool(mockConfig *mocksconfig.Config) { - mockConfig.On("GetInt", "database.pool.max_idle_conns", 10).Return(10) - mockConfig.On("GetInt", "database.pool.max_open_conns", 100).Return(100) - mockConfig.On("GetInt", "database.pool.conn_max_idletime", 3600).Return(3600) - mockConfig.On("GetInt", "database.pool.conn_max_lifetime", 3600).Return(3600) + mockConfig.On("GetInt", "contractsdatabase.pool.max_idle_conns", 10).Return(10) + mockConfig.On("GetInt", "contractsdatabase.pool.max_open_conns", 100).Return(100) + mockConfig.On("GetInt", "contractsdatabase.pool.conn_max_idletime", 3600).Return(3600) + mockConfig.On("GetInt", "contractsdatabase.pool.conn_max_lifetime", 3600).Return(3600) } diff --git a/database/migration/blueprint_test.go b/database/migration/blueprint_test.go index f155107ec..b16b8bd33 100644 --- a/database/migration/blueprint_test.go +++ b/database/migration/blueprint_test.go @@ -6,8 +6,8 @@ import ( "github.com/stretchr/testify/suite" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/migration" - "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/migration/grammars" mocksmigration "github.com/goravel/framework/mocks/database/migration" mocksorm "github.com/goravel/framework/mocks/database/orm" @@ -17,13 +17,13 @@ import ( type BlueprintTestSuite struct { suite.Suite blueprint *Blueprint - grammars map[orm.Driver]migration.Grammar + grammars map[database.Driver]migration.Grammar } func TestBlueprintTestSuite(t *testing.T) { suite.Run(t, &BlueprintTestSuite{ - grammars: map[orm.Driver]migration.Grammar{ - orm.DriverPostgres: grammars.NewPostgres(), + grammars: map[database.Driver]migration.Grammar{ + database.DriverPostgres: grammars.NewPostgres(), }, }) } @@ -355,7 +355,7 @@ func (s *BlueprintTestSuite) TestToSql() { //s.blueprint.String("name").Comment("comment") //s.blueprint.Comment("comment") - if driver == orm.DriverPostgres { + if driver == database.DriverPostgres { s.Len(s.blueprint.ToSql(mockQuery, grammar), 1) } else { s.Empty(s.blueprint.ToSql(mockQuery, grammar)) diff --git a/database/migration/schema.go b/database/migration/schema.go index 9900a205d..51788660c 100644 --- a/database/migration/schema.go +++ b/database/migration/schema.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/goravel/framework/contracts/config" + contractsdatabase "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/migration" contractsorm "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/contracts/log" @@ -73,15 +74,15 @@ func (r *Schema) Sql(sql string) { func getGrammar(driver string) migration.Grammar { switch driver { - case contractsorm.DriverMysql.String(): + case contractsdatabase.DriverMysql.String(): // TODO Optimize here when implementing Mysql driver return nil - case contractsorm.DriverPostgres.String(): + case contractsdatabase.DriverPostgres.String(): return grammars.NewPostgres() - case contractsorm.DriverSqlserver.String(): + case contractsdatabase.DriverSqlserver.String(): // TODO Optimize here when implementing Mysql driver return nil - case contractsorm.DriverSqlite.String(): + case contractsdatabase.DriverSqlite.String(): // TODO Optimize here when implementing Mysql driver return nil default: diff --git a/database/migration/schema_test.go b/database/migration/schema_test.go index 959c9aa46..f6a920b0d 100644 --- a/database/migration/schema_test.go +++ b/database/migration/schema_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" + contractsdatabase "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/migration" contractsorm "github.com/goravel/framework/contracts/database/orm" contractstesting "github.com/goravel/framework/contracts/testing" @@ -24,7 +25,7 @@ type TestDB struct { type SchemaSuite struct { suite.Suite - driverToTestDB map[contractsorm.Driver]TestDB + driverToTestDB map[contractsdatabase.Driver]TestDB } func TestSchemaSuite(t *testing.T) { @@ -38,8 +39,8 @@ func TestSchemaSuite(t *testing.T) { func (s *SchemaSuite) SetupSuite() { postgresDocker := supportdocker.Postgres() postgresQuery := gorm.NewTestQuery(postgresDocker) - s.driverToTestDB = map[contractsorm.Driver]TestDB{ - contractsorm.DriverPostgres: { + s.driverToTestDB = map[contractsdatabase.Driver]TestDB{ + contractsdatabase.DriverPostgres: { config: postgresDocker.Config(), query: postgresQuery.Query(), }, @@ -51,8 +52,8 @@ func (s *SchemaSuite) SetupTest() { } func (s *SchemaSuite) TestConnection() { - schema, mockConfig, _, _ := initTest(s.T(), contractsorm.DriverMysql) - connection := contractsorm.DriverPostgres.String() + schema, mockConfig, _, _ := initTest(s.T(), contractsdatabase.DriverMysql) + connection := contractsdatabase.DriverPostgres.String() mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.prefix", connection)).Return("goravel_").Once() mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.schema", connection)).Return("").Once() mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", connection)).Return(connection).Once() @@ -86,7 +87,7 @@ func (s *SchemaSuite) TestDropIfExists() { } } -func initTest(t *testing.T, driver contractsorm.Driver) (*Schema, *mocksconfig.Config, *mockslog.Log, *mocksorm.Orm) { +func initTest(t *testing.T, driver contractsdatabase.Driver) (*Schema, *mocksconfig.Config, *mockslog.Log, *mocksorm.Orm) { blueprint := NewBlueprint("goravel_", "") mockConfig := mocksconfig.NewConfig(t) mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", driver)). diff --git a/database/migration/sql_driver.go b/database/migration/sql_driver.go index 146e64f58..56264d5c8 100644 --- a/database/migration/sql_driver.go +++ b/database/migration/sql_driver.go @@ -53,20 +53,20 @@ func (r *SqlDriver) getStub(table string, create bool) (string, string) { } driver := r.config.GetString("database.connections." + r.config.GetString("database.default") + ".driver") - switch orm.Driver(driver) { - case orm.DriverPostgres: + switch database.Driver(driver) { + case database.DriverPostgres: if create { return PostgresStubs{}.CreateUp(), PostgresStubs{}.CreateDown() } return PostgresStubs{}.UpdateUp(), PostgresStubs{}.UpdateDown() - case orm.DriverSqlite: + case database.DriverSqlite: if create { return SqliteStubs{}.CreateUp(), SqliteStubs{}.CreateDown() } return SqliteStubs{}.UpdateUp(), SqliteStubs{}.UpdateDown() - case orm.DriverSqlserver: + case database.DriverSqlserver: if create { return SqlserverStubs{}.CreateUp(), SqlserverStubs{}.CreateDown() } diff --git a/database/orm_test.go b/database/orm_test.go index ea178cef3..b2a45b235 100644 --- a/database/orm_test.go +++ b/database/orm_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/suite" + "github.com/goravel/framework/contracts/database" contractsorm "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/gorm" "github.com/goravel/framework/database/orm" @@ -28,7 +29,7 @@ type User struct { type OrmSuite struct { suite.Suite orm *Orm - testQueries map[contractsorm.Driver]*gorm.TestQuery + testQueries map[database.Driver]*gorm.TestQuery } func TestOrmSuite(t *testing.T) { @@ -51,9 +52,9 @@ func (s *OrmSuite) SetupTest() { } s.orm = &Orm{ - connection: contractsorm.DriverPostgres.String(), + connection: database.DriverPostgres.String(), ctx: context.Background(), - query: queries[contractsorm.DriverPostgres.String()], + query: queries[database.DriverPostgres.String()], queries: queries, } } diff --git a/foundation/application_test.go b/foundation/application_test.go index 0aee16d1c..123af9013 100644 --- a/foundation/application_test.go +++ b/foundation/application_test.go @@ -12,7 +12,7 @@ import ( "github.com/goravel/framework/cache" frameworkconfig "github.com/goravel/framework/config" "github.com/goravel/framework/console" - "github.com/goravel/framework/contracts/database/orm" + contractsdatabase "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/foundation" "github.com/goravel/framework/crypt" "github.com/goravel/framework/database" @@ -151,7 +151,7 @@ func (s *ApplicationTestSuite) TestMakeArtisan() { func (s *ApplicationTestSuite) TestMakeAuth() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "auth.defaults.guard").Return("user").Once() + mockConfig.EXPECT().GetString("auth.defaults.guard").Return("user").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -172,9 +172,9 @@ func (s *ApplicationTestSuite) TestMakeAuth() { func (s *ApplicationTestSuite) TestMakeCache() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "cache.default").Return("memory").Once() - mockConfig.On("GetString", "cache.stores.memory.driver").Return("memory").Once() - mockConfig.On("GetString", "cache.prefix").Return("goravel").Once() + mockConfig.EXPECT().GetString("cache.default").Return("memory").Once() + mockConfig.EXPECT().GetString("cache.stores.memory.driver").Return("memory").Once() + mockConfig.EXPECT().GetString("cache.prefix").Return("goravel").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -199,7 +199,7 @@ func (s *ApplicationTestSuite) TestMakeConfig() { func (s *ApplicationTestSuite) TestMakeCrypt() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "app.key").Return("12345678901234567890123456789012").Once() + mockConfig.EXPECT().GetString("app.key").Return("12345678901234567890123456789012").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -244,10 +244,10 @@ func (s *ApplicationTestSuite) TestMakeGrpc() { func (s *ApplicationTestSuite) TestMakeHash() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "hashing.driver", "argon2id").Return("argon2id").Once() - mockConfig.On("GetInt", "hashing.argon2id.time", 4).Return(4).Once() - mockConfig.On("GetInt", "hashing.argon2id.memory", 65536).Return(65536).Once() - mockConfig.On("GetInt", "hashing.argon2id.threads", 1).Return(1).Once() + mockConfig.EXPECT().GetString("hashing.driver", "argon2id").Return("argon2id").Once() + mockConfig.EXPECT().GetInt("hashing.argon2id.time", 4).Return(4).Once() + mockConfig.EXPECT().GetInt("hashing.argon2id.memory", 65536).Return(65536).Once() + mockConfig.EXPECT().GetInt("hashing.argon2id.threads", 1).Return(1).Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -262,9 +262,9 @@ func (s *ApplicationTestSuite) TestMakeHash() { func (s *ApplicationTestSuite) TestMakeLang() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "app.locale").Return("en").Once() - mockConfig.On("GetString", "app.fallback_locale").Return("en").Once() - mockConfig.On("GetString", "app.lang_path", "lang").Return("lang").Once() + mockConfig.EXPECT().GetString("app.locale").Return("en").Once() + mockConfig.EXPECT().GetString("app.fallback_locale").Return("en").Once() + mockConfig.EXPECT().GetString("app.lang_path", "lang").Return("lang").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -310,24 +310,24 @@ func (s *ApplicationTestSuite) TestMakeOrm() { mysqlDocker := supportdocker.Mysql() config := mysqlDocker.Config() mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "database.default").Return("mysql").Once() - mockConfig.On("Get", "database.connections.mysql.read").Return(nil).Once() - mockConfig.On("Get", "database.connections.mysql.write").Return(nil).Once() - mockConfig.On("GetString", "database.connections.mysql.driver").Return(orm.DriverMysql.String()).Twice() - mockConfig.On("GetString", "database.connections.mysql.charset").Return("utf8mb4").Once() - mockConfig.On("GetString", "database.connections.mysql.loc").Return("Local").Once() - mockConfig.On("GetString", "database.connections.mysql.database").Return(config.Database).Once() - mockConfig.On("GetString", "database.connections.mysql.host").Return("localhost").Once() - mockConfig.On("GetString", "database.connections.mysql.username").Return(config.Username).Once() - mockConfig.On("GetString", "database.connections.mysql.password").Return(config.Password).Once() - mockConfig.On("GetString", "database.connections.mysql.prefix").Return("").Once() - mockConfig.On("GetInt", "database.connections.mysql.port").Return(config.Port).Once() - mockConfig.On("GetBool", "database.connections.mysql.singular").Return(true).Once() - mockConfig.On("GetBool", "app.debug").Return(true).Once() - mockConfig.On("GetInt", "database.pool.max_idle_conns", 10).Return(10) - mockConfig.On("GetInt", "database.pool.max_open_conns", 100).Return(100) - mockConfig.On("GetInt", "database.pool.conn_max_idletime", 3600).Return(3600) - mockConfig.On("GetInt", "database.pool.conn_max_lifetime", 3600).Return(3600) + mockConfig.EXPECT().GetString("database.default").Return("mysql").Once() + mockConfig.EXPECT().Get("database.connections.mysql.read").Return(nil).Once() + mockConfig.EXPECT().Get("database.connections.mysql.write").Return(nil).Once() + mockConfig.EXPECT().GetString("database.connections.mysql.driver").Return(contractsdatabase.DriverMysql.String()).Twice() + mockConfig.EXPECT().GetString("database.connections.mysql.charset").Return("utf8mb4").Once() + mockConfig.EXPECT().GetString("database.connections.mysql.loc").Return("Local").Once() + mockConfig.EXPECT().GetString("database.connections.mysql.database").Return(config.Database).Once() + mockConfig.EXPECT().GetString("database.connections.mysql.host").Return("localhost").Once() + mockConfig.EXPECT().GetString("database.connections.mysql.username").Return(config.Username).Once() + mockConfig.EXPECT().GetString("database.connections.mysql.password").Return(config.Password).Once() + mockConfig.EXPECT().GetString("database.connections.mysql.prefix").Return("").Once() + mockConfig.EXPECT().GetInt("database.connections.mysql.port").Return(config.Port).Once() + mockConfig.EXPECT().GetBool("database.connections.mysql.singular").Return(true).Once() + mockConfig.EXPECT().GetBool("app.debug").Return(true).Once() + mockConfig.EXPECT().GetInt("database.pool.max_idle_conns", 10).Return(10) + mockConfig.EXPECT().GetInt("database.pool.max_open_conns", 100).Return(100) + mockConfig.EXPECT().GetInt("database.pool.conn_max_idletime", 3600).Return(3600) + mockConfig.EXPECT().GetInt("database.pool.conn_max_lifetime", 3600).Return(3600) s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -376,7 +376,7 @@ func (s *ApplicationTestSuite) TestMakeRoute() { func (s *ApplicationTestSuite) TestMakeSchedule() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetBool", "app.debug").Return(false).Once() + mockConfig.EXPECT().GetBool("app.debug").Return(false).Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -397,9 +397,9 @@ func (s *ApplicationTestSuite) TestMakeSchedule() { func (s *ApplicationTestSuite) TestMakeSession() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetInt", "session.lifetime").Return(120).Once() - mockConfig.On("GetInt", "session.gc_interval", 30).Return(30).Once() - mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() + mockConfig.EXPECT().GetInt("session.lifetime").Return(120).Once() + mockConfig.EXPECT().GetInt("session.gc_interval", 30).Return(30).Once() + mockConfig.EXPECT().GetString("session.files").Return("storage/framework/sessions").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -418,10 +418,10 @@ func (s *ApplicationTestSuite) TestMakeSession() { func (s *ApplicationTestSuite) TestMakeStorage() { mockConfig := &mocksconfig.Config{} - mockConfig.On("GetString", "filesystems.default").Return("local").Once() - mockConfig.On("GetString", "filesystems.disks.local.driver").Return("local").Once() - mockConfig.On("GetString", "filesystems.disks.local.root").Return("").Once() - mockConfig.On("GetString", "filesystems.disks.local.url").Return("").Once() + mockConfig.EXPECT().GetString("filesystems.default").Return("local").Once() + mockConfig.EXPECT().GetString("filesystems.disks.local.driver").Return("local").Once() + mockConfig.EXPECT().GetString("filesystems.disks.local.root").Return("").Once() + mockConfig.EXPECT().GetString("filesystems.disks.local.url").Return("").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil diff --git a/mocks/database/orm/Query.go b/mocks/database/orm/Query.go index 38feda4e1..a179b320b 100644 --- a/mocks/database/orm/Query.go +++ b/mocks/database/orm/Query.go @@ -3,8 +3,10 @@ package orm import ( - orm "github.com/goravel/framework/contracts/database/orm" + database "github.com/goravel/framework/contracts/database" mock "github.com/stretchr/testify/mock" + + orm "github.com/goravel/framework/contracts/database/orm" ) // Query is an autogenerated mock type for the Query type @@ -444,18 +446,18 @@ func (_c *Query_Distinct_Call) RunAndReturn(run func(...interface{}) orm.Query) } // Driver provides a mock function with given fields: -func (_m *Query) Driver() orm.Driver { +func (_m *Query) Driver() database.Driver { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Driver") } - var r0 orm.Driver - if rf, ok := ret.Get(0).(func() orm.Driver); ok { + var r0 database.Driver + if rf, ok := ret.Get(0).(func() database.Driver); ok { r0 = rf() } else { - r0 = ret.Get(0).(orm.Driver) + r0 = ret.Get(0).(database.Driver) } return r0 @@ -478,12 +480,12 @@ func (_c *Query_Driver_Call) Run(run func()) *Query_Driver_Call { return _c } -func (_c *Query_Driver_Call) Return(_a0 orm.Driver) *Query_Driver_Call { +func (_c *Query_Driver_Call) Return(_a0 database.Driver) *Query_Driver_Call { _c.Call.Return(_a0) return _c } -func (_c *Query_Driver_Call) RunAndReturn(run func() orm.Driver) *Query_Driver_Call { +func (_c *Query_Driver_Call) RunAndReturn(run func() database.Driver) *Query_Driver_Call { _c.Call.Return(run) return _c } diff --git a/mocks/testing/Database.go b/mocks/testing/Database.go index e0d312398..2b01b1743 100644 --- a/mocks/testing/Database.go +++ b/mocks/testing/Database.go @@ -3,10 +3,11 @@ package testing import ( - orm "github.com/goravel/framework/contracts/database/orm" - seeder "github.com/goravel/framework/contracts/database/seeder" + database "github.com/goravel/framework/contracts/database" mock "github.com/stretchr/testify/mock" + seeder "github.com/goravel/framework/contracts/database/seeder" + testing "github.com/goravel/framework/contracts/testing" ) @@ -114,18 +115,18 @@ func (_c *Database_Config_Call) RunAndReturn(run func() testing.DatabaseConfig) } // Driver provides a mock function with given fields: -func (_m *Database) Driver() orm.Driver { +func (_m *Database) Driver() database.Driver { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Driver") } - var r0 orm.Driver - if rf, ok := ret.Get(0).(func() orm.Driver); ok { + var r0 database.Driver + if rf, ok := ret.Get(0).(func() database.Driver); ok { r0 = rf() } else { - r0 = ret.Get(0).(orm.Driver) + r0 = ret.Get(0).(database.Driver) } return r0 @@ -148,12 +149,12 @@ func (_c *Database_Driver_Call) Run(run func()) *Database_Driver_Call { return _c } -func (_c *Database_Driver_Call) Return(_a0 orm.Driver) *Database_Driver_Call { +func (_c *Database_Driver_Call) Return(_a0 database.Driver) *Database_Driver_Call { _c.Call.Return(_a0) return _c } -func (_c *Database_Driver_Call) RunAndReturn(run func() orm.Driver) *Database_Driver_Call { +func (_c *Database_Driver_Call) RunAndReturn(run func() database.Driver) *Database_Driver_Call { _c.Call.Return(run) return _c } diff --git a/mocks/testing/DatabaseDriver.go b/mocks/testing/DatabaseDriver.go index fb584fc46..b10f7b61b 100644 --- a/mocks/testing/DatabaseDriver.go +++ b/mocks/testing/DatabaseDriver.go @@ -3,9 +3,10 @@ package testing import ( - orm "github.com/goravel/framework/contracts/database/orm" - testing "github.com/goravel/framework/contracts/testing" + database "github.com/goravel/framework/contracts/database" mock "github.com/stretchr/testify/mock" + + testing "github.com/goravel/framework/contracts/testing" ) // DatabaseDriver is an autogenerated mock type for the DatabaseDriver type @@ -112,18 +113,18 @@ func (_c *DatabaseDriver_Config_Call) RunAndReturn(run func() testing.DatabaseCo } // Driver provides a mock function with given fields: -func (_m *DatabaseDriver) Driver() orm.Driver { +func (_m *DatabaseDriver) Driver() database.Driver { ret := _m.Called() if len(ret) == 0 { panic("no return value specified for Driver") } - var r0 orm.Driver - if rf, ok := ret.Get(0).(func() orm.Driver); ok { + var r0 database.Driver + if rf, ok := ret.Get(0).(func() database.Driver); ok { r0 = rf() } else { - r0 = ret.Get(0).(orm.Driver) + r0 = ret.Get(0).(database.Driver) } return r0 @@ -146,12 +147,12 @@ func (_c *DatabaseDriver_Driver_Call) Run(run func()) *DatabaseDriver_Driver_Cal return _c } -func (_c *DatabaseDriver_Driver_Call) Return(_a0 orm.Driver) *DatabaseDriver_Driver_Call { +func (_c *DatabaseDriver_Driver_Call) Return(_a0 database.Driver) *DatabaseDriver_Driver_Call { _c.Call.Return(_a0) return _c } -func (_c *DatabaseDriver_Driver_Call) RunAndReturn(run func() orm.Driver) *DatabaseDriver_Driver_Call { +func (_c *DatabaseDriver_Driver_Call) RunAndReturn(run func() database.Driver) *DatabaseDriver_Driver_Call { _c.Call.Return(run) return _c } diff --git a/support/docker/mysql.go b/support/docker/mysql.go index 569d83657..6b37c24df 100644 --- a/support/docker/mysql.go +++ b/support/docker/mysql.go @@ -7,7 +7,7 @@ import ( "gorm.io/driver/mysql" gormio "gorm.io/gorm" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/testing" ) @@ -75,8 +75,8 @@ func (receiver *MysqlImpl) Config() testing.DatabaseConfig { } } -func (receiver *MysqlImpl) Driver() orm.Driver { - return orm.DriverMysql +func (receiver *MysqlImpl) Driver() database.Driver { + return database.DriverMysql } func (receiver *MysqlImpl) Fresh() error { diff --git a/support/docker/mysql_test.go b/support/docker/mysql_test.go index d5c216c59..4129f9037 100644 --- a/support/docker/mysql_test.go +++ b/support/docker/mysql_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/suite" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" contractstesting "github.com/goravel/framework/contracts/testing" configmocks "github.com/goravel/framework/mocks/config" "github.com/goravel/framework/support/env" @@ -73,7 +73,7 @@ INSERT INTO users (name) VALUES ('goravel'); } func (s *MysqlTestSuite) TestDriver() { - s.Equal(orm.DriverMysql, s.mysql.Driver()) + s.Equal(database.DriverMysql, s.mysql.Driver()) } func (s *MysqlTestSuite) TestImage() { diff --git a/support/docker/postgres.go b/support/docker/postgres.go index ec86e53d5..c69ea9749 100644 --- a/support/docker/postgres.go +++ b/support/docker/postgres.go @@ -7,7 +7,7 @@ import ( "gorm.io/driver/postgres" gormio "gorm.io/gorm" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/testing" ) @@ -70,8 +70,8 @@ func (receiver *PostgresImpl) Config() testing.DatabaseConfig { } } -func (receiver *PostgresImpl) Driver() orm.Driver { - return orm.DriverPostgres +func (receiver *PostgresImpl) Driver() database.Driver { + return database.DriverPostgres } func (receiver *PostgresImpl) Fresh() error { diff --git a/support/docker/postgres_test.go b/support/docker/postgres_test.go index 664563f45..9c4866394 100644 --- a/support/docker/postgres_test.go +++ b/support/docker/postgres_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/suite" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" contractstesting "github.com/goravel/framework/contracts/testing" configmocks "github.com/goravel/framework/mocks/config" "github.com/goravel/framework/support/env" @@ -75,7 +75,7 @@ func (s *PostgresTestSuite) TestBuild() { } func (s *PostgresTestSuite) TestDriver() { - s.Equal(orm.DriverPostgres, s.postgres.Driver()) + s.Equal(database.DriverPostgres, s.postgres.Driver()) } func (s *PostgresTestSuite) TestImage() { diff --git a/support/docker/sqlite.go b/support/docker/sqlite.go index 7d4c901b0..01cc733da 100644 --- a/support/docker/sqlite.go +++ b/support/docker/sqlite.go @@ -6,7 +6,7 @@ import ( "github.com/glebarez/sqlite" gormio "gorm.io/gorm" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/testing" "github.com/goravel/framework/support/file" ) @@ -36,8 +36,8 @@ func (receiver *SqliteImpl) Config() testing.DatabaseConfig { } } -func (receiver *SqliteImpl) Driver() orm.Driver { - return orm.DriverSqlite +func (receiver *SqliteImpl) Driver() database.Driver { + return database.DriverSqlite } func (receiver *SqliteImpl) Fresh() error { diff --git a/support/docker/sqlite_test.go b/support/docker/sqlite_test.go index a90f3660e..e20a076c2 100644 --- a/support/docker/sqlite_test.go +++ b/support/docker/sqlite_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/suite" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" contractstesting "github.com/goravel/framework/contracts/testing" configmocks "github.com/goravel/framework/mocks/config" "github.com/goravel/framework/support/env" @@ -71,7 +71,7 @@ INSERT INTO users (name) VALUES ('goravel'); } func (s *SqliteTestSuite) TestDriver() { - s.Equal(orm.DriverSqlite, s.sqlite.Driver()) + s.Equal(database.DriverSqlite, s.sqlite.Driver()) } func (s *SqliteTestSuite) TestImage() { diff --git a/support/docker/sqlserver.go b/support/docker/sqlserver.go index b647d3d88..2e39a4c04 100644 --- a/support/docker/sqlserver.go +++ b/support/docker/sqlserver.go @@ -7,7 +7,7 @@ import ( "gorm.io/driver/sqlserver" gormio "gorm.io/gorm" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/testing" ) @@ -69,8 +69,8 @@ func (receiver *SqlserverImpl) Config() testing.DatabaseConfig { } } -func (receiver *SqlserverImpl) Driver() orm.Driver { - return orm.DriverSqlserver +func (receiver *SqlserverImpl) Driver() database.Driver { + return database.DriverSqlserver } func (receiver *SqlserverImpl) Fresh() error { diff --git a/support/docker/sqlserver_test.go b/support/docker/sqlserver_test.go index 9426abd00..89494fd70 100644 --- a/support/docker/sqlserver_test.go +++ b/support/docker/sqlserver_test.go @@ -76,7 +76,7 @@ func (s *SqlserverTestSuite) TestBuild() { } func (s *SqlserverTestSuite) TestDriver() { - s.Equal(orm.DriverSqlserver, s.sqlserver.Driver()) + s.Equal(database.DriverSqlserver, s.sqlserver.Driver()) } func (s *SqlserverTestSuite) TestImage() { diff --git a/testing/docker/database_test.go b/testing/docker/database_test.go index bb9505a8f..30b72a244 100644 --- a/testing/docker/database_test.go +++ b/testing/docker/database_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" - contractsorm "github.com/goravel/framework/contracts/database/orm" + contractsdatabase "github.com/goravel/framework/contracts/database" frameworkdatabase "github.com/goravel/framework/database" mocksconfig "github.com/goravel/framework/mocks/config" mocksconsole "github.com/goravel/framework/mocks/console" @@ -48,14 +48,13 @@ func TestNewDatabase(t *testing.T) { name: "success when connection is empty", setup: func() { mockConfig.EXPECT().GetString("database.default").Return("mysql").Once() - mockConfig.EXPECT().GetString("database.connections.mysql.driver").Return(contractsorm.DriverMysql.String()).Once() + mockConfig.EXPECT().GetString("database.connections.mysql.driver").Return(contractsdatabase.DriverMysql.String()).Once() mockConfig.EXPECT().GetString("database.connections.mysql.database").Return(testDatabase).Once() mockConfig.EXPECT().GetString("database.connections.mysql.username").Return(testUsername).Once() mockConfig.EXPECT().GetString("database.connections.mysql.password").Return(testPassword).Once() }, wantDatabase: func() *Database { return &Database{ - app: mockApp, artisan: mockArtisan, config: mockConfig, connection: "mysql", @@ -67,14 +66,13 @@ func TestNewDatabase(t *testing.T) { name: "success when connection is mysql", connection: "mysql", setup: func() { - mockConfig.EXPECT().GetString("database.connections.mysql.driver").Return(contractsorm.DriverMysql.String()).Once() + mockConfig.EXPECT().GetString("database.connections.mysql.driver").Return(contractsdatabase.DriverMysql.String()).Once() mockConfig.EXPECT().GetString("database.connections.mysql.database").Return(testDatabase).Once() mockConfig.EXPECT().GetString("database.connections.mysql.username").Return(testUsername).Once() mockConfig.EXPECT().GetString("database.connections.mysql.password").Return(testPassword).Once() }, wantDatabase: func() *Database { return &Database{ - app: mockApp, artisan: mockArtisan, config: mockConfig, connection: "mysql", @@ -86,14 +84,13 @@ func TestNewDatabase(t *testing.T) { name: "success when connection is postgres", connection: "postgres", setup: func() { - mockConfig.EXPECT().GetString("database.connections.postgres.driver").Return(contractsorm.DriverPostgres.String()).Once() + mockConfig.EXPECT().GetString("database.connections.postgres.driver").Return(contractsdatabase.DriverPostgres.String()).Once() mockConfig.EXPECT().GetString("database.connections.postgres.database").Return(testDatabase).Once() mockConfig.EXPECT().GetString("database.connections.postgres.username").Return(testUsername).Once() mockConfig.EXPECT().GetString("database.connections.postgres.password").Return(testPassword).Once() }, wantDatabase: func() *Database { return &Database{ - app: mockApp, artisan: mockArtisan, config: mockConfig, connection: "postgres", @@ -105,14 +102,13 @@ func TestNewDatabase(t *testing.T) { name: "success when connection is sqlserver", connection: "sqlserver", setup: func() { - mockConfig.EXPECT().GetString("database.connections.sqlserver.driver").Return(contractsorm.DriverSqlserver.String()).Once() + mockConfig.EXPECT().GetString("database.connections.sqlserver.driver").Return(contractsdatabase.DriverSqlserver.String()).Once() mockConfig.EXPECT().GetString("database.connections.sqlserver.database").Return(testDatabase).Once() mockConfig.EXPECT().GetString("database.connections.sqlserver.username").Return(testUsername).Once() mockConfig.EXPECT().GetString("database.connections.sqlserver.password").Return(testPassword).Once() }, wantDatabase: func() *Database { return &Database{ - app: mockApp, artisan: mockArtisan, config: mockConfig, connection: "sqlserver", @@ -124,14 +120,13 @@ func TestNewDatabase(t *testing.T) { name: "success when connection is sqlite", connection: "sqlite", setup: func() { - mockConfig.EXPECT().GetString("database.connections.sqlite.driver").Return(contractsorm.DriverSqlite.String()).Once() + mockConfig.EXPECT().GetString("database.connections.sqlite.driver").Return(contractsdatabase.DriverSqlite.String()).Once() mockConfig.EXPECT().GetString("database.connections.sqlite.database").Return(testDatabase).Once() mockConfig.EXPECT().GetString("database.connections.sqlite.username").Return(testUsername).Once() mockConfig.EXPECT().GetString("database.connections.sqlite.password").Return(testPassword).Once() }, wantDatabase: func() *Database { return &Database{ - app: mockApp, artisan: mockArtisan, config: mockConfig, connection: "sqlite", @@ -170,7 +165,6 @@ func (s *DatabaseTestSuite) SetupTest() { s.mockArtisan = mocksconsole.NewArtisan(s.T()) s.mockConfig = mocksconfig.NewConfig(s.T()) s.database = &Database{ - app: s.mockApp, artisan: s.mockArtisan, config: s.mockConfig, connection: "mysql", From c5c4706706dc496e218931f18d42bafebc361896 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 17:42:18 +0800 Subject: [PATCH 03/11] fix test --- database/db/dsn.go | 25 ++--- database/db/dsn_test.go | 6 + database/gorm/dialector_test.go | 184 +++++++++++++++++-------------- database/gorm/test_utils.go | 132 +++++++++++----------- support/docker/sqlserver_test.go | 8 +- 5 files changed, 186 insertions(+), 169 deletions(-) diff --git a/database/db/dsn.go b/database/db/dsn.go index 492a64b45..595ead700 100644 --- a/database/db/dsn.go +++ b/database/db/dsn.go @@ -7,33 +7,22 @@ import ( ) func Dsn(config database.FullConfig) string { + if config.Host == "" { + return "" + } + switch config.Driver { case database.DriverMysql: - host := config.Host - if host == "" { - return "" - } - return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", - config.Username, config.Password, host, config.Port, config.Database, config.Charset, true, config.Loc) + config.Username, config.Password, config.Host, config.Port, config.Database, config.Charset, true, config.Loc) case database.DriverPostgres: - host := config.Host - if host == "" { - return "" - } - return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", - config.Username, config.Password, host, config.Port, config.Database, config.Sslmode, config.Timezone) + config.Username, config.Password, config.Host, config.Port, config.Database, config.Sslmode, config.Timezone) case database.DriverSqlite: return fmt.Sprintf("%s?multi_stmts=true", config.Database) case database.DriverSqlserver: - host := config.Host - if host == "" { - return "" - } - return fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", - config.Username, config.Password, host, config.Port, config.Database, config.Charset) + config.Username, config.Password, config.Host, config.Port, config.Database, config.Charset) default: return "" } diff --git a/database/db/dsn_test.go b/database/db/dsn_test.go index 7a5ca4245..eaa6bbc92 100644 --- a/database/db/dsn_test.go +++ b/database/db/dsn_test.go @@ -31,6 +31,12 @@ func TestDsn(t *testing.T) { config database.FullConfig expectDsn string }{ + { + name: "empty", + config: database.FullConfig{ + Config: database.Config{}, + }, + }, { name: "mysql", config: database.FullConfig{ diff --git a/database/gorm/dialector_test.go b/database/gorm/dialector_test.go index b27c5e9e4..cb475de94 100644 --- a/database/gorm/dialector_test.go +++ b/database/gorm/dialector_test.go @@ -1,90 +1,112 @@ package gorm import ( + "fmt" "testing" -) -//type DialectorTestSuite struct { -// suite.Suite -// mockConfig *configmock.Config -// config databasecontract.Config -//} -// -//func TestDialectorTestSuite(t *testing.T) { -// suite.Run(t, &DialectorTestSuite{ -// config: databasecontract.Config{ -// Host: "localhost", -// Port: 3306, -// Database: "forge", -// Username: "root", -// Password: "123123", -// }, -// }) -//} -// -//func (s *DialectorTestSuite) SetupTest() { -// s.mockConfig = &configmock.Config{} -//} -// -//func (s *DialectorTestSuite) TestMysql() { -// dialector := NewDialector(s.mockConfig, orm.DriverMysql.String()) -// s.mockConfig.On("GetString", "database.connections.mysql.driver"). -// Return(orm.DriverMysql.String()).Once() -// s.mockConfig.On("GetString", "database.connections.mysql.charset"). -// Return("utf8mb4").Once() -// s.mockConfig.On("GetString", "database.connections.mysql.loc"). -// Return("Local").Once() -// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) -// s.Nil(err) -// s.NotEmpty(dialectors) -// s.Equal(mysql.New(mysql.Config{ -// DSN: fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=%s&multiStatements=true", -// s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "utf8mb4", true, "Local"), -// }), dialectors[0]) -//} -// -//func (s *DialectorTestSuite) TestPostgres() { -// dialector := NewDialector(s.mockConfig, orm.DriverPostgres.String()) -// s.mockConfig.On("GetString", "database.connections.postgres.driver"). -// Return(orm.DriverPostgres.String()).Once() -// s.mockConfig.On("GetString", "database.connections.postgres.sslmode"). -// Return("disable").Once() -// s.mockConfig.On("GetString", "database.connections.postgres.timezone"). -// Return("UTC").Once() -// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) -// s.Nil(err) -// s.NotEmpty(dialectors) -// s.Equal(postgres.New(postgres.Config{ -// DSN: fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=%s&timezone=%s", -// s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "disable", "UTC"), -// }), dialectors[0]) -//} -// -//func (s *DialectorTestSuite) TestSqlite() { -// dialector := NewDialector(s.mockConfig, orm.DriverSqlite.String()) -// s.mockConfig.On("GetString", "database.connections.sqlite.driver"). -// Return(orm.DriverSqlite.String()).Once() -// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) -// s.Nil(err) -// s.NotEmpty(dialectors) -// s.Equal(sqlite.Open(fmt.Sprintf("%s?multi_stmts=true", s.config.Database)), dialectors[0]) -//} -// -//func (s *DialectorTestSuite) TestSqlserver() { -// dialector := NewDialector(s.mockConfig, orm.DriverSqlserver.String()) -// s.mockConfig.On("GetString", "database.connections.sqlserver.driver"). -// Return(orm.DriverSqlserver.String()).Once() -// s.mockConfig.On("GetString", "database.connections.sqlserver.charset"). -// Return("utf8mb4").Once() -// dialectors, err := dialector.Make([]databasecontract.Config{s.config}) -// s.Nil(err) -// s.NotEmpty(dialectors) -// s.Equal(sqlserver.New(sqlserver.Config{ -// DSN: fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&charset=%s&MultipleActiveResultSets=true", -// s.config.Username, s.config.Password, s.config.Host, s.config.Port, s.config.Database, "utf8mb4"), -// }), dialectors[0]) -//} + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "github.com/goravel/framework/contracts/database" +) func TestGetDialectors(t *testing.T) { + config := database.Config{ + Host: "localhost", + } + + tests := []struct { + name string + configs []database.FullConfig + expectDialectors func(dialector gorm.Dialector) bool + expectError error + }{ + { + name: "Sad path - dsn is empty", + configs: []database.FullConfig{ + { + Connection: "postgres", + }, + }, + expectError: fmt.Errorf("failed to get dsn for postgres"), + }, + { + name: "Happy path - mysql", + configs: []database.FullConfig{ + { + Connection: "mysql", + Driver: database.DriverMysql, + Config: config, + }, + }, + expectDialectors: func(dialector gorm.Dialector) bool { + _, ok := dialector.(*mysql.Dialector) + + return ok + }, + }, + { + name: "Happy path - postgres", + configs: []database.FullConfig{ + { + Connection: "postgres", + Driver: database.DriverPostgres, + Config: config, + }, + }, + expectDialectors: func(dialector gorm.Dialector) bool { + _, ok := dialector.(*postgres.Dialector) + + return ok + }, + }, + { + name: "Happy path - sqlserver", + configs: []database.FullConfig{ + { + Connection: "sqlserver", + Driver: database.DriverSqlserver, + Config: config, + }, + }, + expectDialectors: func(dialector gorm.Dialector) bool { + _, ok := dialector.(*sqlserver.Dialector) + + return ok + }, + }, + { + name: "Happy path - sqlite", + configs: []database.FullConfig{ + { + Connection: "sqlite", + Driver: database.DriverSqlite, + Config: config, + }, + }, + expectDialectors: func(dialector gorm.Dialector) bool { + _, ok := dialector.(*sqlite.Dialector) + + return ok + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dialectors, err := GetDialectors(test.configs) + if test.expectError != nil { + assert.Equal(t, test.expectError, err) + assert.Nil(t, dialectors) + } else { + assert.NoError(t, err) + assert.Len(t, dialectors, 1) + assert.True(t, test.expectDialectors(dialectors[0])) + } + }) + } } diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index 010506609..2a9910bda 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -20,7 +20,7 @@ const ( TestModelNormal // Switch this value to control the test model. - TestModel = TestModelMinimum + TestModel = TestModelNormal ) type TestTable int @@ -322,49 +322,49 @@ func NewMockMysql(mockConfig *mocksconfig.Config, connection, database, username func (r *MockMysql) Common() { r.mockConfig.On("GetString", "contractsdatabase.default").Return("mysql") r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockMysql) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.ReadPort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.WritePort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockMysql) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockMysql) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.charset", r.connection)).Return("utf8mb4") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.loc", r.connection)).Return("Local") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.charset", r.connection)).Return("utf8mb4") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.loc", r.connection)).Return("Local") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) mockPool(r.mockConfig) } func (r *MockMysql) single() { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.host", r.connection)).Return("127.0.0.1") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.username", r.connection)).Return(r.user) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.password", r.connection)).Return(r.password) - r.mockConfig.On("GetInt", fmt.Sprintf("contractsdatabase.connections.%s.port", r.connection)).Return(r.port) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", r.connection)).Return("127.0.0.1") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", r.connection)).Return(r.user) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", r.connection)).Return(r.password) + r.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", r.connection)).Return(r.port) } type MockPostgres struct { @@ -393,48 +393,48 @@ func NewMockPostgres(mockConfig *mocksconfig.Config, connection, database, usern func (r *MockPostgres) Common() { r.mockConfig.On("GetString", "contractsdatabase.default").Return("postgres") r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockPostgres) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.ReadPort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.WritePort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockPostgres) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockPostgres) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.sslmode", r.connection)).Return("disable") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.timezone", r.connection)).Return("UTC") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.sslmode", r.connection)).Return("disable") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.timezone", r.connection)).Return("UTC") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) mockPool(r.mockConfig) } func (r *MockPostgres) single() { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.host", r.connection)).Return("127.0.0.1") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.username", r.connection)).Return(r.user) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.password", r.connection)).Return(r.password) - r.mockConfig.On("GetInt", fmt.Sprintf("contractsdatabase.connections.%s.port", r.connection)).Return(r.port) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", r.connection)).Return("127.0.0.1") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", r.connection)).Return(r.user) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", r.connection)).Return(r.password) + r.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", r.connection)).Return(r.port) } type MockSqlite struct { @@ -457,41 +457,41 @@ func NewMockSqlite(mockConfig *mocksconfig.Config, connection, database string) func (r *MockSqlite) Common() { r.mockConfig.On("GetString", "contractsdatabase.default").Return("sqlite") r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockSqlite) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Database: config.ReadDatabase}, }) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Database: r.database}, }) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockSqlite) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockSqlite) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) mockPool(r.mockConfig) } func (r *MockSqlite) single() { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) } type MockSqlserver struct { @@ -520,46 +520,46 @@ func NewMockSqlserver(mockConfig *mocksconfig.Config, connection, database, user func (r *MockSqlserver) Common() { r.mockConfig.On("GetString", "contractsdatabase.default").Return("sqlserver") r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() r.basic() } func (r *MockSqlserver) ReadWrite(config TestReadWriteConfig) { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.ReadPort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return([]contractsdatabase.Config{ {Host: "127.0.0.1", Port: config.WritePort, Username: r.user, Password: r.password}, }) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(false) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.basic() } func (r *MockSqlserver) WithPrefixAndSingular() { - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.prefix", r.connection)).Return("goravel_") - r.mockConfig.On("GetBool", fmt.Sprintf("contractsdatabase.connections.%s.singular", r.connection)).Return(true) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("goravel_") + r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(true) r.single() r.basic() } func (r *MockSqlserver) basic() { r.mockConfig.On("GetBool", "app.debug").Return(true) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.driver", r.connection)).Return(r.driver.String()) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.database", r.connection)).Return(r.database) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.charset", r.connection)).Return("utf8mb4") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.driver", r.connection)).Return(r.driver.String()) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.database", r.connection)).Return(r.database) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.charset", r.connection)).Return("utf8mb4") mockPool(r.mockConfig) } func (r *MockSqlserver) single() { - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.read", r.connection)).Return(nil) - r.mockConfig.On("Get", fmt.Sprintf("contractsdatabase.connections.%s.write", r.connection)).Return(nil) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.host", r.connection)).Return("127.0.0.1") - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.username", r.connection)).Return(r.user) - r.mockConfig.On("GetString", fmt.Sprintf("contractsdatabase.connections.%s.password", r.connection)).Return(r.password) - r.mockConfig.On("GetInt", fmt.Sprintf("contractsdatabase.connections.%s.port", r.connection)).Return(r.port) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.read", r.connection)).Return(nil) + r.mockConfig.On("Get", fmt.Sprintf("database.connections.%s.write", r.connection)).Return(nil) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.host", r.connection)).Return("127.0.0.1") + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.username", r.connection)).Return(r.user) + r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.password", r.connection)).Return(r.password) + r.mockConfig.On("GetInt", fmt.Sprintf("database.connections.%s.port", r.connection)).Return(r.port) } type testTables struct { diff --git a/support/docker/sqlserver_test.go b/support/docker/sqlserver_test.go index 89494fd70..ebc1a0aea 100644 --- a/support/docker/sqlserver_test.go +++ b/support/docker/sqlserver_test.go @@ -5,15 +5,15 @@ import ( "github.com/stretchr/testify/suite" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" contractstesting "github.com/goravel/framework/contracts/testing" - configmocks "github.com/goravel/framework/mocks/config" + mocksconfig "github.com/goravel/framework/mocks/config" "github.com/goravel/framework/support/env" ) type SqlserverTestSuite struct { suite.Suite - mockConfig *configmocks.Config + mockConfig *mocksconfig.Config sqlserver *SqlserverImpl } @@ -26,7 +26,7 @@ func TestSqlserverTestSuite(t *testing.T) { } func (s *SqlserverTestSuite) SetupTest() { - s.mockConfig = &configmocks.Config{} + s.mockConfig = &mocksconfig.Config{} s.sqlserver = NewSqlserverImpl(testDatabase, testUsername, testPassword) } From 2ae891804e47bb78b619ec1818711de4bb3a74d1 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 17:57:14 +0800 Subject: [PATCH 04/11] fix test --- database/db/dsn.go | 2 +- database/gorm/test_utils.go | 10 +++++----- database/migration/sql_driver.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/database/db/dsn.go b/database/db/dsn.go index 595ead700..5b2a0628e 100644 --- a/database/db/dsn.go +++ b/database/db/dsn.go @@ -7,7 +7,7 @@ import ( ) func Dsn(config database.FullConfig) string { - if config.Host == "" { + if config.Host == "" && config.Driver != database.DriverSqlite { return "" } diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index 2a9910bda..f2242fea9 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -20,7 +20,7 @@ const ( TestModelNormal // Switch this value to control the test model. - TestModel = TestModelNormal + TestModel = TestModelMinimum ) type TestTable int @@ -1214,8 +1214,8 @@ CREATE TABLE role_user ( } func mockPool(mockConfig *mocksconfig.Config) { - mockConfig.On("GetInt", "contractsdatabase.pool.max_idle_conns", 10).Return(10) - mockConfig.On("GetInt", "contractsdatabase.pool.max_open_conns", 100).Return(100) - mockConfig.On("GetInt", "contractsdatabase.pool.conn_max_idletime", 3600).Return(3600) - mockConfig.On("GetInt", "contractsdatabase.pool.conn_max_lifetime", 3600).Return(3600) + mockConfig.On("GetInt", "database.pool.max_idle_conns", 10).Return(10) + mockConfig.On("GetInt", "database.pool.max_open_conns", 100).Return(100) + mockConfig.On("GetInt", "database.pool.conn_max_idletime", 3600).Return(3600) + mockConfig.On("GetInt", "database.pool.conn_max_lifetime", 3600).Return(3600) } diff --git a/database/migration/sql_driver.go b/database/migration/sql_driver.go index 56264d5c8..5fc3ac973 100644 --- a/database/migration/sql_driver.go +++ b/database/migration/sql_driver.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/goravel/framework/contracts/config" - "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/support/carbon" "github.com/goravel/framework/support/file" ) From 9bb103cabedc15d5d4c4b0cf751a8e5f672a2802 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 18:11:10 +0800 Subject: [PATCH 05/11] fix test --- database/gorm/test_utils.go | 2 +- testing/docker/database_test.go | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index f2242fea9..c171b2668 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -20,7 +20,7 @@ const ( TestModelNormal // Switch this value to control the test model. - TestModel = TestModelMinimum + TestModel = TestModelNormal ) type TestTable int diff --git a/testing/docker/database_test.go b/testing/docker/database_test.go index 30b72a244..f2698c03a 100644 --- a/testing/docker/database_test.go +++ b/testing/docker/database_test.go @@ -8,9 +8,9 @@ import ( "github.com/stretchr/testify/suite" contractsdatabase "github.com/goravel/framework/contracts/database" - frameworkdatabase "github.com/goravel/framework/database" mocksconfig "github.com/goravel/framework/mocks/config" mocksconsole "github.com/goravel/framework/mocks/console" + mocksorm "github.com/goravel/framework/mocks/database/orm" mocksfoundation "github.com/goravel/framework/mocks/foundation" supportdocker "github.com/goravel/framework/support/docker" "github.com/goravel/framework/support/env" @@ -27,14 +27,17 @@ func TestNewDatabase(t *testing.T) { mockApp *mocksfoundation.Application mockArtisan *mocksconsole.Artisan mockConfig *mocksconfig.Config + mockOrm *mocksorm.Orm ) beforeEach := func() { mockApp = mocksfoundation.NewApplication(t) mockArtisan = mocksconsole.NewArtisan(t) mockConfig = mocksconfig.NewConfig(t) + mockOrm = mocksorm.NewOrm(t) mockApp.EXPECT().MakeArtisan().Return(mockArtisan).Once() mockApp.EXPECT().MakeConfig().Return(mockConfig).Once() + mockApp.EXPECT().MakeOrm().Return(mockOrm).Once() } tests := []struct { @@ -58,6 +61,7 @@ func TestNewDatabase(t *testing.T) { artisan: mockArtisan, config: mockConfig, connection: "mysql", + orm: mockOrm, DatabaseDriver: supportdocker.NewMysqlImpl(testDatabase, testUsername, testPassword), } }, @@ -76,6 +80,7 @@ func TestNewDatabase(t *testing.T) { artisan: mockArtisan, config: mockConfig, connection: "mysql", + orm: mockOrm, DatabaseDriver: supportdocker.NewMysqlImpl(testDatabase, testUsername, testPassword), } }, @@ -94,6 +99,7 @@ func TestNewDatabase(t *testing.T) { artisan: mockArtisan, config: mockConfig, connection: "postgres", + orm: mockOrm, DatabaseDriver: supportdocker.NewPostgresImpl(testDatabase, testUsername, testPassword), } }, @@ -112,6 +118,7 @@ func TestNewDatabase(t *testing.T) { artisan: mockArtisan, config: mockConfig, connection: "sqlserver", + orm: mockOrm, DatabaseDriver: supportdocker.NewSqlserverImpl(testDatabase, testUsername, testPassword), } }, @@ -130,6 +137,7 @@ func TestNewDatabase(t *testing.T) { artisan: mockArtisan, config: mockConfig, connection: "sqlite", + orm: mockOrm, DatabaseDriver: supportdocker.NewSqliteImpl(testDatabase), } }, @@ -153,6 +161,7 @@ type DatabaseTestSuite struct { mockApp *mocksfoundation.Application mockArtisan *mocksconsole.Artisan mockConfig *mocksconfig.Config + mockOrm *mocksorm.Orm database *Database } @@ -164,11 +173,13 @@ func (s *DatabaseTestSuite) SetupTest() { s.mockApp = mocksfoundation.NewApplication(s.T()) s.mockArtisan = mocksconsole.NewArtisan(s.T()) s.mockConfig = mocksconfig.NewConfig(s.T()) + s.mockOrm = mocksorm.NewOrm(s.T()) s.database = &Database{ artisan: s.mockArtisan, config: s.mockConfig, - connection: "mysql", - DatabaseDriver: supportdocker.NewMysqlImpl(testDatabase, testUsername, testPassword), + connection: "postgres", + orm: s.mockOrm, + DatabaseDriver: supportdocker.NewPostgresImpl(testDatabase, testUsername, testPassword), } } @@ -177,9 +188,9 @@ func (s *DatabaseTestSuite) TestBuild() { s.T().Skip("Skipping tests of using docker") } - s.mockConfig.EXPECT().Add("database.connections.mysql.port", mock.Anything).Once() + s.mockConfig.EXPECT().Add("database.connections.postgres.port", mock.Anything).Once() s.mockArtisan.EXPECT().Call("migrate").Once() - s.mockApp.EXPECT().Singleton(frameworkdatabase.BindingOrm, mock.Anything).Once() + s.mockOrm.EXPECT().Refresh().Once() s.Nil(s.database.Build()) s.True(s.database.Config().Port > 0) @@ -196,11 +207,9 @@ func (s *DatabaseTestSuite) TestConfig() { func (s *DatabaseTestSuite) TestSeed() { s.mockArtisan.EXPECT().Call("db:seed").Once() - s.database.Seed() s.mockArtisan.EXPECT().Call("db:seed --seeder mock").Once() - s.database.Seed(&MockSeeder{}) } From 3a99442c25775a3af643e57111ebf22bdee62f18 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 18:58:36 +0800 Subject: [PATCH 06/11] fix test --- database/gorm/test_utils.go | 18 +++++++-------- database/orm_test.go | 5 ----- foundation/application_test.go | 41 +++++++++++++++++----------------- foundation/container_test.go | 14 +++++++++++- testing/docker/docker_test.go | 5 +++++ 5 files changed, 47 insertions(+), 36 deletions(-) diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index c171b2668..ab055aeb7 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -20,7 +20,7 @@ const ( TestModelNormal // Switch this value to control the test model. - TestModel = TestModelNormal + TestModel = TestModelMinimum ) type TestTable int @@ -320,8 +320,8 @@ func NewMockMysql(mockConfig *mocksconfig.Config, connection, database, username } func (r *MockMysql) Common() { - r.mockConfig.On("GetString", "contractsdatabase.default").Return("mysql") - r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", "database.default").Return("mysql") + r.mockConfig.On("GetString", "database.migrations").Return("migrations") r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() @@ -391,8 +391,8 @@ func NewMockPostgres(mockConfig *mocksconfig.Config, connection, database, usern } func (r *MockPostgres) Common() { - r.mockConfig.On("GetString", "contractsdatabase.default").Return("postgres") - r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", "database.default").Return("postgres") + r.mockConfig.On("GetString", "database.migrations").Return("migrations") r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() @@ -455,8 +455,8 @@ func NewMockSqlite(mockConfig *mocksconfig.Config, connection, database string) } func (r *MockSqlite) Common() { - r.mockConfig.On("GetString", "contractsdatabase.default").Return("sqlite") - r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", "database.default").Return("sqlite") + r.mockConfig.On("GetString", "database.migrations").Return("migrations") r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() @@ -518,8 +518,8 @@ func NewMockSqlserver(mockConfig *mocksconfig.Config, connection, database, user } func (r *MockSqlserver) Common() { - r.mockConfig.On("GetString", "contractsdatabase.default").Return("sqlserver") - r.mockConfig.On("GetString", "contractsdatabase.migrations").Return("migrations") + r.mockConfig.On("GetString", "database.default").Return("sqlserver") + r.mockConfig.On("GetString", "database.migrations").Return("migrations") r.mockConfig.On("GetString", fmt.Sprintf("database.connections.%s.prefix", r.connection)).Return("") r.mockConfig.On("GetBool", fmt.Sprintf("database.connections.%s.singular", r.connection)).Return(false) r.single() diff --git a/database/orm_test.go b/database/orm_test.go index b2a45b235..52a188325 100644 --- a/database/orm_test.go +++ b/database/orm_test.go @@ -12,7 +12,6 @@ import ( "github.com/goravel/framework/database/gorm" "github.com/goravel/framework/database/orm" "github.com/goravel/framework/support/env" - "github.com/goravel/framework/support/file" ) type contextKey int @@ -59,10 +58,6 @@ func (s *OrmSuite) SetupTest() { } } -func (s *OrmSuite) TearDownSuite() { - s.Nil(file.Remove("goravel")) -} - func (s *OrmSuite) TestConnection() { for driver := range s.testQueries { s.NotNil(s.orm.Connection(driver.String())) diff --git a/foundation/application_test.go b/foundation/application_test.go index 123af9013..a53c2bbd8 100644 --- a/foundation/application_test.go +++ b/foundation/application_test.go @@ -307,27 +307,27 @@ func (s *ApplicationTestSuite) TestMakeOrm() { s.T().Skip("Skipping tests of using docker") } - mysqlDocker := supportdocker.Mysql() - config := mysqlDocker.Config() - mockConfig := &mocksconfig.Config{} - mockConfig.EXPECT().GetString("database.default").Return("mysql").Once() - mockConfig.EXPECT().Get("database.connections.mysql.read").Return(nil).Once() - mockConfig.EXPECT().Get("database.connections.mysql.write").Return(nil).Once() - mockConfig.EXPECT().GetString("database.connections.mysql.driver").Return(contractsdatabase.DriverMysql.String()).Twice() - mockConfig.EXPECT().GetString("database.connections.mysql.charset").Return("utf8mb4").Once() - mockConfig.EXPECT().GetString("database.connections.mysql.loc").Return("Local").Once() - mockConfig.EXPECT().GetString("database.connections.mysql.database").Return(config.Database).Once() - mockConfig.EXPECT().GetString("database.connections.mysql.host").Return("localhost").Once() - mockConfig.EXPECT().GetString("database.connections.mysql.username").Return(config.Username).Once() - mockConfig.EXPECT().GetString("database.connections.mysql.password").Return(config.Password).Once() - mockConfig.EXPECT().GetString("database.connections.mysql.prefix").Return("").Once() - mockConfig.EXPECT().GetInt("database.connections.mysql.port").Return(config.Port).Once() - mockConfig.EXPECT().GetBool("database.connections.mysql.singular").Return(true).Once() + postgresDocker := supportdocker.Postgres() + config := postgresDocker.Config() + mockConfig := mocksconfig.NewConfig(s.T()) + mockConfig.EXPECT().GetString("database.default").Return("postgres").Once() + mockConfig.EXPECT().Get("database.connections.postgres.read").Return(nil).Once() + mockConfig.EXPECT().Get("database.connections.postgres.write").Return(nil).Once() + mockConfig.EXPECT().GetString("database.connections.postgres.driver").Return(contractsdatabase.DriverPostgres.String()).Once() + mockConfig.EXPECT().GetString("database.connections.postgres.prefix").Return("").Once() + mockConfig.EXPECT().GetBool("database.connections.postgres.singular").Return(true).Once() + mockConfig.EXPECT().GetString("database.connections.postgres.host").Return("localhost").Once() + mockConfig.EXPECT().GetString("database.connections.postgres.username").Return(config.Username).Once() + mockConfig.EXPECT().GetString("database.connections.postgres.password").Return(config.Password).Once() + mockConfig.EXPECT().GetInt("database.connections.postgres.port").Return(config.Port).Once() + mockConfig.EXPECT().GetString("database.connections.postgres.sslmode").Return("disable").Once() + mockConfig.EXPECT().GetString("database.connections.postgres.timezone").Return("UTC").Once() + mockConfig.EXPECT().GetString("database.connections.postgres.database").Return(config.Database).Once() mockConfig.EXPECT().GetBool("app.debug").Return(true).Once() - mockConfig.EXPECT().GetInt("database.pool.max_idle_conns", 10).Return(10) - mockConfig.EXPECT().GetInt("database.pool.max_open_conns", 100).Return(100) - mockConfig.EXPECT().GetInt("database.pool.conn_max_idletime", 3600).Return(3600) - mockConfig.EXPECT().GetInt("database.pool.conn_max_lifetime", 3600).Return(3600) + mockConfig.EXPECT().GetInt("database.pool.max_idle_conns", 10).Return(10).Once() + mockConfig.EXPECT().GetInt("database.pool.max_open_conns", 100).Return(100).Once() + mockConfig.EXPECT().GetInt("database.pool.conn_max_idletime", 3600).Return(3600).Once() + mockConfig.EXPECT().GetInt("database.pool.conn_max_lifetime", 3600).Return(3600).Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil @@ -337,7 +337,6 @@ func (s *ApplicationTestSuite) TestMakeOrm() { serviceProvider.Register(s.app) s.NotNil(s.app.MakeOrm()) - mockConfig.AssertExpectations(s.T()) } func (s *ApplicationTestSuite) TestMakeQueue() { diff --git a/foundation/container_test.go b/foundation/container_test.go index a9ac578e1..4ed2f1269 100644 --- a/foundation/container_test.go +++ b/foundation/container_test.go @@ -97,8 +97,20 @@ func (s *ContainerTestSuite) TestSingleton() { s.Equal(1, concreteImpl) s.Nil(err) default: - s.T().Errorf("error") + panic("concrete err") } + + s.container.Refresh("Singleton") + _, exist = s.container.instances.Load("Singleton") + s.False(exist) + + s.container.Singleton("Singleton", callback) + concrete, exist = s.container.bindings.Load("Singleton") + s.True(exist) + ins, ok = concrete.(instance) + s.True(ok) + s.True(ins.shared) + s.NotNil(ins.concrete) } func (s *ContainerTestSuite) TestMake() { diff --git a/testing/docker/docker_test.go b/testing/docker/docker_test.go index dd13f20e8..e1efc9048 100644 --- a/testing/docker/docker_test.go +++ b/testing/docker/docker_test.go @@ -7,6 +7,7 @@ import ( mocksconfig "github.com/goravel/framework/mocks/config" mocksconsole "github.com/goravel/framework/mocks/console" + mocksorm "github.com/goravel/framework/mocks/database/orm" mocksfoundation "github.com/goravel/framework/mocks/foundation" ) @@ -34,9 +35,11 @@ func (s *DockerTestSuite) TestDatabase() { mockConfig.EXPECT().GetString("database.connections.mysql.password").Return("goravel").Once() mockArtisan := mocksconsole.NewArtisan(s.T()) + mockOrm := mocksorm.NewOrm(s.T()) s.mockApp.EXPECT().MakeArtisan().Return(mockArtisan).Once() s.mockApp.EXPECT().MakeConfig().Return(mockConfig).Once() + s.mockApp.EXPECT().MakeOrm().Return(mockOrm).Once() database, err := s.docker.Database() s.Nil(err) @@ -52,9 +55,11 @@ func (s *DockerTestSuite) TestDatabase() { mockConfig.EXPECT().GetString("database.connections.postgres.password").Return("goravel").Once() mockArtisan = mocksconsole.NewArtisan(s.T()) + mockOrm = mocksorm.NewOrm(s.T()) s.mockApp.EXPECT().MakeArtisan().Return(mockArtisan).Once() s.mockApp.On("MakeConfig").Return(mockConfig).Once() + s.mockApp.EXPECT().MakeOrm().Return(mockOrm).Once() database, err = s.docker.Database("postgres") s.Nil(err) From aee982cd5985ae27f3876bdea01f211bd25014ee Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 22:17:48 +0800 Subject: [PATCH 07/11] remove global variable --- database/gorm/test_utils.go | 2 +- database/orm.go | 10 ++++++---- database/service_provider.go | 5 +---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index ab055aeb7..dd8e79078 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -20,7 +20,7 @@ const ( TestModelNormal // Switch this value to control the test model. - TestModel = TestModelMinimum + TestModel = TestModelNormal ) type TestTable int diff --git a/database/orm.go b/database/orm.go index 02675015f..25ce5bef9 100644 --- a/database/orm.go +++ b/database/orm.go @@ -18,9 +18,10 @@ type Orm struct { connection string query contractsorm.Query queries map[string]contractsorm.Query + refresh func(key any) } -func NewOrm(ctx context.Context, config config.Config, connection string, query contractsorm.Query) (*Orm, error) { +func NewOrm(ctx context.Context, config config.Config, connection string, query contractsorm.Query, refresh func(key any)) (*Orm, error) { return &Orm{ ctx: ctx, config: config, @@ -29,16 +30,17 @@ func NewOrm(ctx context.Context, config config.Config, connection string, query queries: map[string]contractsorm.Query{ connection: query, }, + refresh: refresh, }, nil } -func BuildOrm(ctx context.Context, config config.Config, connection string) (*Orm, error) { +func BuildOrm(ctx context.Context, config config.Config, connection string, refresh func(key any)) (*Orm, error) { query, err := gorm.BuildQuery(ctx, config, connection) if err != nil { return nil, fmt.Errorf("[Orm] Build query for %s connection error: %v", connection, err) } - return NewOrm(ctx, config, connection, query) + return NewOrm(ctx, config, connection, query, refresh) } func (r *Orm) Connection(name string) contractsorm.Orm { @@ -95,7 +97,7 @@ func (r *Orm) Observe(model any, observer contractsorm.Observer) { } func (r *Orm) Refresh() { - appFacade.Refresh(BindingOrm) + r.refresh(BindingOrm) } func (r *Orm) Transaction(txFunc func(tx contractsorm.Query) error) error { diff --git a/database/service_provider.go b/database/service_provider.go index 06110b797..77db293b7 100644 --- a/database/service_provider.go +++ b/database/service_provider.go @@ -14,8 +14,6 @@ const BindingOrm = "goravel.orm" const BindingSchema = "goravel.schema" const BindingSeeder = "goravel.seeder" -var appFacade foundation.Application - type ServiceProvider struct { } @@ -24,7 +22,7 @@ func (r *ServiceProvider) Register(app foundation.Application) { ctx := context.Background() config := app.MakeConfig() connection := config.GetString("database.default") - orm, err := BuildOrm(ctx, config, connection) + orm, err := BuildOrm(ctx, config, connection, app.Refresh) if err != nil { return nil, fmt.Errorf("[Orm] Init %s connection error: %v", connection, err) } @@ -49,7 +47,6 @@ func (r *ServiceProvider) Register(app foundation.Application) { } func (r *ServiceProvider) Boot(app foundation.Application) { - appFacade = app r.registerCommands(app) } From de9d92e9820a8a891e12cdc27b4b040d4457b958 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 22:37:30 +0800 Subject: [PATCH 08/11] optimize based on ai --- contracts/database/config.go | 2 +- database/console/migrate.go | 4 ++-- database/db/configs.go | 20 ++++++++++---------- database/db/configs_test.go | 22 +++++++++++----------- database/gorm/dialector.go | 4 ++-- database/gorm/dialector_test.go | 4 ++-- database/gorm/gorm.go | 22 +++++++++++----------- database/gorm/query.go | 4 ++-- database/orm.go | 17 +++++++++++------ foundation/container_test.go | 2 +- 10 files changed, 53 insertions(+), 48 deletions(-) diff --git a/contracts/database/config.go b/contracts/database/config.go index 641d3c8b5..7d1157ae6 100644 --- a/contracts/database/config.go +++ b/contracts/database/config.go @@ -35,7 +35,7 @@ type FullConfig struct { Timezone string // Postgres } -type Configs interface { +type ConfigBuilder interface { Reads() []FullConfig Writes() []FullConfig } diff --git a/database/console/migrate.go b/database/console/migrate.go index ea952557e..26193d604 100644 --- a/database/console/migrate.go +++ b/database/console/migrate.go @@ -25,8 +25,8 @@ func getMigrate(config config.Config) (*migrate.Migrate, error) { dir = fmt.Sprintf("file://%s/database/migrations", support.RelativePath) } - configs := databasedb.NewConfigs(config, connection) - writeConfigs := configs.Writes() + configBuilder := databasedb.NewConfigBuilder(config, connection) + writeConfigs := configBuilder.Writes() if len(writeConfigs) == 0 { return nil, errors.New("not found database configuration") } diff --git a/database/db/configs.go b/database/db/configs.go index 1e8315206..2ae4ac473 100644 --- a/database/db/configs.go +++ b/database/db/configs.go @@ -7,38 +7,38 @@ import ( "github.com/goravel/framework/contracts/database" ) -type Configs struct { +type ConfigBuilder struct { config contractsconfig.Config connection string } -func NewConfigs(config contractsconfig.Config, connection string) *Configs { - return &Configs{ +func NewConfigBuilder(config contractsconfig.Config, connection string) *ConfigBuilder { + return &ConfigBuilder{ config: config, connection: connection, } } -func (c *Configs) Reads() []database.FullConfig { +func (c *ConfigBuilder) Reads() []database.FullConfig { configs := c.config.Get(fmt.Sprintf("database.connections.%s.read", c.connection)) - if configs, ok := configs.([]database.Config); ok { - return c.fillDefault(configs) + if readConfigs, ok := configs.([]database.Config); ok { + return c.fillDefault(readConfigs) } return nil } -func (c *Configs) Writes() []database.FullConfig { +func (c *ConfigBuilder) Writes() []database.FullConfig { configs := c.config.Get(fmt.Sprintf("database.connections.%s.write", c.connection)) - if configs, ok := configs.([]database.Config); ok { - return c.fillDefault(configs) + if writeConfigs, ok := configs.([]database.Config); ok { + return c.fillDefault(writeConfigs) } // Use default db configuration when write is empty return c.fillDefault([]database.Config{{}}) } -func (c *Configs) fillDefault(configs []database.Config) []database.FullConfig { +func (c *ConfigBuilder) fillDefault(configs []database.Config) []database.FullConfig { if len(configs) == 0 { return nil } diff --git a/database/db/configs_test.go b/database/db/configs_test.go index 08db5c588..86dba04e3 100644 --- a/database/db/configs_test.go +++ b/database/db/configs_test.go @@ -12,9 +12,9 @@ import ( type ConfigTestSuite struct { suite.Suite - configs *Configs - connection string - mockConfig *mocksconfig.Config + configBuilder *ConfigBuilder + connection string + mockConfig *mocksconfig.Config } func TestConfigTestSuite(t *testing.T) { @@ -25,7 +25,7 @@ func TestConfigTestSuite(t *testing.T) { func (s *ConfigTestSuite) SetupTest() { s.mockConfig = mocksconfig.NewConfig(s.T()) - s.configs = NewConfigs(s.mockConfig, s.connection) + s.configBuilder = NewConfigBuilder(s.mockConfig, s.connection) } func (s *ConfigTestSuite) TestReads() { @@ -35,7 +35,7 @@ func (s *ConfigTestSuite) TestReads() { // Test when configs is empty s.mockConfig.EXPECT().Get("database.connections.mysql.read").Return(nil).Once() - s.Nil(s.configs.Reads()) + s.Nil(s.configBuilder.Reads()) // Test when configs is not empty s.mockConfig.EXPECT().Get("database.connections.mysql.read").Return([]contractsdatabase.Config{ @@ -56,7 +56,7 @@ func (s *ConfigTestSuite) TestReads() { Database: database, }, }, - }, s.configs.Reads()) + }, s.configBuilder.Reads()) } func (s *ConfigTestSuite) TestWrites() { @@ -64,7 +64,7 @@ func (s *ConfigTestSuite) TestWrites() { prefix := "goravel_" singular := false - // Test when configs is empty + // Test when configBuilder is empty s.mockConfig.EXPECT().Get("database.connections.mysql.write").Return(nil).Once() s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.driver", s.connection)).Return(contractsdatabase.DriverSqlite.String()).Once() s.mockConfig.EXPECT().GetString(fmt.Sprintf("database.connections.%s.database", s.connection)).Return(database).Once() @@ -80,9 +80,9 @@ func (s *ConfigTestSuite) TestWrites() { Database: database, }, }, - }, s.configs.Writes()) + }, s.configBuilder.Writes()) - // Test when configs is not empty + // Test when configBuilder is not empty s.mockConfig.EXPECT().Get("database.connections.mysql.write").Return([]contractsdatabase.Config{ { Database: database, @@ -101,7 +101,7 @@ func (s *ConfigTestSuite) TestWrites() { Database: database, }, }, - }, s.configs.Writes()) + }, s.configBuilder.Writes()) } func (s *ConfigTestSuite) TestFillDefault() { @@ -224,7 +224,7 @@ func (s *ConfigTestSuite) TestFillDefault() { for _, test := range tests { s.Run(test.name, func() { test.setup() - configs := s.configs.fillDefault(test.configs) + configs := s.configBuilder.fillDefault(test.configs) s.Equal(test.expectConfigs, configs) }) diff --git a/database/gorm/dialector.go b/database/gorm/dialector.go index 5d44f3ad9..c337a50e1 100644 --- a/database/gorm/dialector.go +++ b/database/gorm/dialector.go @@ -13,14 +13,14 @@ import ( "github.com/goravel/framework/database/db" ) -func GetDialectors(configs []database.FullConfig) ([]gorm.Dialector, error) { +func getDialectors(configs []database.FullConfig) ([]gorm.Dialector, error) { var dialectors []gorm.Dialector for _, config := range configs { var dialector gorm.Dialector dsn := db.Dsn(config) if dsn == "" { - return nil, fmt.Errorf("failed to get dsn for %s", config.Connection) + return nil, fmt.Errorf("failed to generate DSN for connection '%s'", config.Connection) } switch config.Driver { diff --git a/database/gorm/dialector_test.go b/database/gorm/dialector_test.go index cb475de94..e5cdeaa37 100644 --- a/database/gorm/dialector_test.go +++ b/database/gorm/dialector_test.go @@ -98,9 +98,9 @@ func TestGetDialectors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - dialectors, err := GetDialectors(test.configs) + dialectors, err := getDialectors(test.configs) if test.expectError != nil { - assert.Equal(t, test.expectError, err) + assert.EqualError(t, err, test.expectError.Error()) assert.Nil(t, dialectors) } else { assert.NoError(t, err) diff --git a/database/gorm/gorm.go b/database/gorm/gorm.go index 8ba01fc9c..48b2ee5ce 100644 --- a/database/gorm/gorm.go +++ b/database/gorm/gorm.go @@ -18,23 +18,23 @@ import ( ) type Builder struct { - config config.Config - configs database.Configs - instance *gormio.DB + config config.Config + configBuilder database.ConfigBuilder + instance *gormio.DB } -func NewGorm(config config.Config, configs database.Configs) (*gormio.DB, error) { +func NewGorm(config config.Config, configBuilder database.ConfigBuilder) (*gormio.DB, error) { builder := &Builder{ - config: config, - configs: configs, + config: config, + configBuilder: configBuilder, } return builder.Build() } func (r *Builder) Build() (*gormio.DB, error) { - readConfigs := r.configs.Reads() - writeConfigs := r.configs.Writes() + readConfigs := r.configBuilder.Reads() + writeConfigs := r.configBuilder.Writes() if len(writeConfigs) == 0 { return nil, errors.New("not found database configuration") } @@ -73,12 +73,12 @@ func (r *Builder) configureReadWriteSeparate(readConfigs, writeConfigs []databas return nil } - readDialectors, err := GetDialectors(readConfigs) + readDialectors, err := getDialectors(readConfigs) if err != nil { return err } - writeDialectors, err := GetDialectors(writeConfigs) + writeDialectors, err := getDialectors(writeConfigs) if err != nil { return err } @@ -92,7 +92,7 @@ func (r *Builder) configureReadWriteSeparate(readConfigs, writeConfigs []databas } func (r *Builder) init(fullConfig database.FullConfig) error { - dialectors, err := GetDialectors([]database.FullConfig{fullConfig}) + dialectors, err := getDialectors([]database.FullConfig{fullConfig}) if err != nil { return fmt.Errorf("init gorm dialector error: %v", err) } diff --git a/database/gorm/query.go b/database/gorm/query.go index 04eaeaff8..17757d2b9 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -49,8 +49,8 @@ func NewQuery(ctx context.Context, config config.Config, connection string, db * } func BuildQuery(ctx context.Context, config config.Config, connection string) (*Query, error) { - configs := db.NewConfigs(config, connection) - gorm, err := NewGorm(config, configs) + configBuilder := db.NewConfigBuilder(config, connection) + gorm, err := NewGorm(config, configBuilder) if err != nil { return nil, err } diff --git a/database/orm.go b/database/orm.go index 25ce5bef9..264b537ba 100644 --- a/database/orm.go +++ b/database/orm.go @@ -76,7 +76,10 @@ func (r *Orm) Connection(name string) contractsorm.Orm { } func (r *Orm) DB() (*sql.DB, error) { - query := r.Query().(*gorm.Query) + query, ok := r.Query().(*gorm.Query) + if !ok { + return nil, fmt.Errorf("unexpected Query type %T, expected *gorm.Query", r.Query()) + } return query.Instance().DB() } @@ -119,18 +122,20 @@ func (r *Orm) Transaction(txFunc func(tx contractsorm.Query) error) error { func (r *Orm) WithContext(ctx context.Context) contractsorm.Orm { for _, query := range r.queries { - query := query.(*gorm.Query) - query.SetContext(ctx) + if gormQuery, ok := query.(*gorm.Query); ok { + gormQuery.SetContext(ctx) + } } - query := r.query.(*gorm.Query) - query.SetContext(ctx) + if gormQuery, ok := r.query.(*gorm.Query); ok { + gormQuery.SetContext(ctx) + } return &Orm{ ctx: ctx, config: r.config, connection: r.connection, - query: query, + query: r.query, queries: r.queries, } } diff --git a/foundation/container_test.go b/foundation/container_test.go index 4ed2f1269..05a7cd511 100644 --- a/foundation/container_test.go +++ b/foundation/container_test.go @@ -97,7 +97,7 @@ func (s *ContainerTestSuite) TestSingleton() { s.Equal(1, concreteImpl) s.Nil(err) default: - panic("concrete err") + s.Fail("concrete err") } s.container.Refresh("Singleton") From 4bb9026a9485a9c708bb0c70461b8fd930fe5bdc Mon Sep 17 00:00:00 2001 From: hwbrzzl Date: Tue, 1 Oct 2024 14:38:19 +0000 Subject: [PATCH 09/11] chore: update mocks --- mocks/database/ConfigBuilder.go | 129 ++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 mocks/database/ConfigBuilder.go diff --git a/mocks/database/ConfigBuilder.go b/mocks/database/ConfigBuilder.go new file mode 100644 index 000000000..328036c30 --- /dev/null +++ b/mocks/database/ConfigBuilder.go @@ -0,0 +1,129 @@ +// Code generated by mockery. DO NOT EDIT. + +package database + +import ( + database "github.com/goravel/framework/contracts/database" + mock "github.com/stretchr/testify/mock" +) + +// ConfigBuilder is an autogenerated mock type for the ConfigBuilder type +type ConfigBuilder struct { + mock.Mock +} + +type ConfigBuilder_Expecter struct { + mock *mock.Mock +} + +func (_m *ConfigBuilder) EXPECT() *ConfigBuilder_Expecter { + return &ConfigBuilder_Expecter{mock: &_m.Mock} +} + +// Reads provides a mock function with given fields: +func (_m *ConfigBuilder) Reads() []database.FullConfig { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Reads") + } + + var r0 []database.FullConfig + if rf, ok := ret.Get(0).(func() []database.FullConfig); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.FullConfig) + } + } + + return r0 +} + +// ConfigBuilder_Reads_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Reads' +type ConfigBuilder_Reads_Call struct { + *mock.Call +} + +// Reads is a helper method to define mock.On call +func (_e *ConfigBuilder_Expecter) Reads() *ConfigBuilder_Reads_Call { + return &ConfigBuilder_Reads_Call{Call: _e.mock.On("Reads")} +} + +func (_c *ConfigBuilder_Reads_Call) Run(run func()) *ConfigBuilder_Reads_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *ConfigBuilder_Reads_Call) Return(_a0 []database.FullConfig) *ConfigBuilder_Reads_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ConfigBuilder_Reads_Call) RunAndReturn(run func() []database.FullConfig) *ConfigBuilder_Reads_Call { + _c.Call.Return(run) + return _c +} + +// Writes provides a mock function with given fields: +func (_m *ConfigBuilder) Writes() []database.FullConfig { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Writes") + } + + var r0 []database.FullConfig + if rf, ok := ret.Get(0).(func() []database.FullConfig); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]database.FullConfig) + } + } + + return r0 +} + +// ConfigBuilder_Writes_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Writes' +type ConfigBuilder_Writes_Call struct { + *mock.Call +} + +// Writes is a helper method to define mock.On call +func (_e *ConfigBuilder_Expecter) Writes() *ConfigBuilder_Writes_Call { + return &ConfigBuilder_Writes_Call{Call: _e.mock.On("Writes")} +} + +func (_c *ConfigBuilder_Writes_Call) Run(run func()) *ConfigBuilder_Writes_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *ConfigBuilder_Writes_Call) Return(_a0 []database.FullConfig) *ConfigBuilder_Writes_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ConfigBuilder_Writes_Call) RunAndReturn(run func() []database.FullConfig) *ConfigBuilder_Writes_Call { + _c.Call.Return(run) + return _c +} + +// NewConfigBuilder creates a new instance of ConfigBuilder. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewConfigBuilder(t interface { + mock.TestingT + Cleanup(func()) +}) *ConfigBuilder { + mock := &ConfigBuilder{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From 3e6bb94761448cb4e15b62eca343c0331e653e2a Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 23:32:43 +0800 Subject: [PATCH 10/11] fix test --- database/db/configs.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/database/db/configs.go b/database/db/configs.go index 2ae4ac473..53b8c633a 100644 --- a/database/db/configs.go +++ b/database/db/configs.go @@ -67,10 +67,10 @@ func (c *ConfigBuilder) fillDefault(configs []database.Config) []database.FullCo if fullConfig.Password == "" { fullConfig.Password = c.config.GetString(fmt.Sprintf("database.connections.%s.password", c.connection)) } - if driver == database.DriverMysql { + if driver == database.DriverMysql || driver == database.DriverSqlserver { fullConfig.Charset = c.config.GetString(fmt.Sprintf("database.connections.%s.charset", c.connection)) } - if driver == database.DriverMysql || driver == database.DriverSqlserver { + if driver == database.DriverMysql { fullConfig.Loc = c.config.GetString(fmt.Sprintf("database.connections.%s.loc", c.connection)) } if driver == database.DriverPostgres { From 61aa33fb8017329e309692f0759c35256540f195 Mon Sep 17 00:00:00 2001 From: Bowen Date: Tue, 1 Oct 2024 23:40:13 +0800 Subject: [PATCH 11/11] fix test --- database/gorm/dialector.go | 2 +- database/gorm/dialector_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/database/gorm/dialector.go b/database/gorm/dialector.go index c337a50e1..4902d90f1 100644 --- a/database/gorm/dialector.go +++ b/database/gorm/dialector.go @@ -20,7 +20,7 @@ func getDialectors(configs []database.FullConfig) ([]gorm.Dialector, error) { var dialector gorm.Dialector dsn := db.Dsn(config) if dsn == "" { - return nil, fmt.Errorf("failed to generate DSN for connection '%s'", config.Connection) + return nil, fmt.Errorf("failed to generate DSN for connection: %s", config.Connection) } switch config.Driver { diff --git a/database/gorm/dialector_test.go b/database/gorm/dialector_test.go index e5cdeaa37..a43ee9898 100644 --- a/database/gorm/dialector_test.go +++ b/database/gorm/dialector_test.go @@ -32,7 +32,7 @@ func TestGetDialectors(t *testing.T) { Connection: "postgres", }, }, - expectError: fmt.Errorf("failed to get dsn for postgres"), + expectError: fmt.Errorf("failed to generate DSN for connection: postgres"), }, { name: "Happy path - mysql",