Skip to content

Commit

Permalink
refactor: 将 ModelEngine 并入 Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Apr 7, 2024
1 parent 95b6293 commit b80608e
Show file tree
Hide file tree
Showing 49 changed files with 297 additions and 254 deletions.
8 changes: 4 additions & 4 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func BenchmarkDB_Insert(b *testing.B) {
Created: time.Now(),
}

suite := test.NewSuite(a, benchDBDriverName)
suite := test.NewSuite(a, "", benchDBDriverName)

suite.Run(func(t *test.Driver) {
t.NotError(t.DB.Create(&Group{}))
Expand All @@ -49,7 +49,7 @@ func BenchmarkDB_Update(b *testing.B) {
Created: time.Now(),
}

suite := test.NewSuite(a, benchDBDriverName)
suite := test.NewSuite(a, "", benchDBDriverName)

suite.Run(func(t *test.Driver) {
t.NotError(t.DB.Create(&Group{}))
Expand Down Expand Up @@ -79,7 +79,7 @@ func BenchmarkDB_Select(b *testing.B) {
Created: time.Now(),
}

suite := test.NewSuite(a, benchDBDriverName)
suite := test.NewSuite(a, "", benchDBDriverName)

suite.Run(func(t *test.Driver) {
t.NotError(t.DB.Create(&Group{}))
Expand All @@ -106,7 +106,7 @@ func BenchmarkDB_WhereUpdate(b *testing.B) {
Created: time.Now(),
}

suite := test.NewSuite(a, benchDBDriverName)
suite := test.NewSuite(a, "", benchDBDriverName)

suite.Run(func(t *test.Driver) {
t.NotError(t.DB.Create(&Group{}))
Expand Down
11 changes: 11 additions & 0 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ type IndexType int8

type ConstraintType int8

// TablePrefix 表名前缀
//
// 当需要在一个数据库中创建不同的实例,
// 或是同一个数据模式应用在不同的对象是,可以通过不同的表名前缀对数据表进行区分。
type TablePrefix interface {
// TablePrefix 所有数据表拥有的统一表名前缀
TablePrefix() string
}

// Engine 数据库执行的基本接口
//
// orm.DB 和 orm.Tx 应该实现此接口。
Expand All @@ -54,6 +63,8 @@ type Engine interface {
Prepare(query string) (*Stmt, error)

PrepareContext(ctx context.Context, query string) (*Stmt, error)

TablePrefix
}

// Dialect 用于描述与数据库和驱动相关的一些特性
Expand Down
24 changes: 14 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
// DB 数据库操作实例
type DB struct {
*sql.DB
dialect Dialect
sqlBuilder *sqlbuilder.SQLBuilder
models *model.Models
version string
replacer *strings.Replacer
tablePrefix string
dialect Dialect
sqlBuilder *sqlbuilder.SQLBuilder
models *model.Models
version string
replacer *strings.Replacer

sqlLogger func(string)
}
Expand All @@ -35,23 +36,24 @@ func defaultSQLLogger(string) {}
// - postgres 已经固定为 UTC;
// - sqlite3 可以在 dsn 中通过 _loc=UTC 指定;
// - mysql 默认是 UTC,也可以在 DSN 中通过 loc=UTC 指定;
func NewDB(dsn string, dialect Dialect) (*DB, error) {
func NewDB(tablePrefix, dsn string, dialect Dialect) (*DB, error) {
db, err := sql.Open(dialect.DriverName(), dsn)
if err != nil {
return nil, err
}
return NewDBWithStdDB(db, dialect)
return NewDBWithStdDB(tablePrefix, db, dialect)
}

// NewDBWithStdDB 从 [sql.DB] 构建 [DB] 实例
//
// NOTE: 请确保用于打开 db 的 driverName 参数与 dialect.DriverName() 是相同的,
// 否则后续操作的结果是未知的。
func NewDBWithStdDB(db *sql.DB, dialect Dialect) (*DB, error) {
func NewDBWithStdDB(tablePRefix string, db *sql.DB, dialect Dialect) (*DB, error) {
l, r := dialect.Quotes()
inst := &DB{
DB: db,
dialect: dialect,
DB: db,
tablePrefix: tablePRefix,
dialect: dialect,
replacer: strings.NewReplacer(
string(core.QuoteLeft), string(l),
string(core.QuoteRight), string(r),
Expand All @@ -66,6 +68,8 @@ func NewDBWithStdDB(db *sql.DB, dialect Dialect) (*DB, error) {
return inst, nil
}

func (db *DB) TablePrefix() string { return db.tablePrefix }

// Debug 指定调输出调试内容通道
//
// 如果 l 不为 nil,则每次 SQL 调用都会输出 SQL 语句,预编译的语句,仅在预编译时输出;
Expand Down
20 changes: 10 additions & 10 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestMain(m *testing.M) {

func TestDB_LastInsertID(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
t.NotError(t.DB.Create(&User{}))
Expand All @@ -46,7 +46,7 @@ func (v *defvalues) TableName() string { return "defvalues" }

func TestDB_InsertDefaultValues(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(d *test.Driver) {
d.NotError(d.DB.Create(&defvalues{}))
Expand All @@ -66,7 +66,7 @@ func TestDB_InsertDefaultValues(t *testing.T) {

func TestDB_Update(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
initData(t)
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestDB_Update(t *testing.T) {

func TestDB_Update_occ(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
initData(t)
Expand Down Expand Up @@ -170,7 +170,7 @@ func TestDB_Update_occ(t *testing.T) {

func TestDB_Update_error(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

// 多个唯一约束符合查询条件
suite.Run(func(t *test.Driver) {
Expand All @@ -190,7 +190,7 @@ func TestDB_Update_error(t *testing.T) {

func TestDB_Delete(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
initData(t)
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestDB_Delete(t *testing.T) {

func TestDB_Truncate(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
initData(t)
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestDB_Truncate(t *testing.T) {

func TestDB_Drop(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
initData(t)
Expand All @@ -280,7 +280,7 @@ func TestDB_Drop(t *testing.T) {

func TestDB_Version(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
v, err := t.DB.Version()
Expand All @@ -291,7 +291,7 @@ func TestDB_Version(t *testing.T) {

func TestDB_Debug(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
buf := new(bytes.Buffer)
Expand Down
12 changes: 6 additions & 6 deletions dialect/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (

func TestMysql_VersionSQL(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Mysql, test.Mariadb)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)

suite.Run(func(t *test.Driver) {
testDialectVersionSQL(t)
Expand All @@ -27,7 +27,7 @@ func TestMysql_VersionSQL(t *testing.T) {

func TestMysql_DropConstrainStmtHook(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Mysql, test.Mariadb)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)

// 约束名不是根据 core.pkName 生成的
suite.Run(func(t *test.Driver) {
Expand Down Expand Up @@ -72,7 +72,7 @@ func TestMysql_DropConstrainStmtHook(t *testing.T) {
func TestMysql_DropIndexSQL(t *testing.T) {
a := assert.New(t, false)

suite := test.NewSuite(a, test.Mysql, test.Mariadb)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)

suite.Run(func(t *test.Driver) {
qs, err := t.DB.Dialect().DropIndexSQL("tbl", "index_name")
Expand All @@ -83,7 +83,7 @@ func TestMysql_DropIndexSQL(t *testing.T) {
func TestMysql_TruncateTableSQL(t *testing.T) {
a := assert.New(t, false)

suite := test.NewSuite(a, test.Mysql, test.Mariadb)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)

suite.Run(func(t *test.Driver) {
qs, err := t.DB.Dialect().TruncateTableSQL("tbl", "")
Expand Down Expand Up @@ -270,7 +270,7 @@ func TestMysql_SQLType(t *testing.T) {

func TestMysql_Types(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Mysql, test.Mariadb)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)

suite.Run(func(t *test.Driver) {
testTypes(t)
Expand All @@ -279,7 +279,7 @@ func TestMysql_Types(t *testing.T) {

func TestMysql_TypesDefault(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Mysql, test.Mariadb)
suite := test.NewSuite(a, "", test.Mysql, test.Mariadb)

suite.Run(func(t *test.Driver) {
testTypesDefault(t)
Expand Down
8 changes: 4 additions & 4 deletions dialect/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

func TestPostgres_VersionSQL(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Postgres)
suite := test.NewSuite(a, "", test.Postgres)

suite.Run(func(t *test.Driver) {
testDialectVersionSQL(t)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestPostgres_SQLType(t *testing.T) {
func TestPostgres_TruncateTableSQL(t *testing.T) {
a := assert.New(t, false)

suite := test.NewSuite(a, test.Postgres)
suite := test.NewSuite(a, "", test.Postgres)

suite.Run(func(t *test.Driver) {
stmt := sqlbuilder.TruncateTable(t.DB).Table("tbl", "")
Expand Down Expand Up @@ -302,7 +302,7 @@ func BenchmarkPostgres_Fix(b *testing.B) {

func TestPostgres_Types(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Postgres)
suite := test.NewSuite(a, "", test.Postgres)

suite.Run(func(t *test.Driver) {
testTypes(t)
Expand All @@ -311,7 +311,7 @@ func TestPostgres_Types(t *testing.T) {

func TestPostgres_TypesDefault(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Postgres)
suite := test.NewSuite(a, "", test.Postgres)

suite.Run(func(t *test.Driver) {
testTypesDefault(t)
Expand Down
14 changes: 7 additions & 7 deletions dialect/sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func clearSqlite3CreateTable(t *test.Driver, db core.Engine) {

func TestSqlite3_VersionSQL(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
testDialectVersionSQL(t)
Expand All @@ -62,7 +62,7 @@ func TestSqlite3_VersionSQL(t *testing.T) {

func TestSqlite3_AddConstraintStmtHook(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
db := t.DB
Expand All @@ -87,7 +87,7 @@ func TestSqlite3_AddConstraintStmtHook(t *testing.T) {
func TestSqlite3_DropConstraintStmtHook(t *testing.T) {
a := assert.New(t, false)

suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
db := t.DB
Expand Down Expand Up @@ -143,7 +143,7 @@ func testMysqlDropConstraintStmtHook(t *test.Driver) {

func TestSqlite3_DropColumnStmtHook(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
db := t.DB
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestSqlite3_CreateTableOptions(t *testing.T) {
func TestSqlite3_TruncateTableSQL(t *testing.T) {
a := assert.New(t, false)

suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
qs, err := t.DB.Dialect().TruncateTableSQL("tbl", "")
Expand Down Expand Up @@ -313,7 +313,7 @@ func TestSqlite3_SQLType(t *testing.T) {

func TestSqlite3_Types(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
testTypes(t)
Expand All @@ -322,7 +322,7 @@ func TestSqlite3_Types(t *testing.T) {

func TestSqlite3_TypesDefault(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, test.Sqlite3)
suite := test.NewSuite(a, "", test.Sqlite3)

suite.Run(func(t *test.Driver) {
testTypesDefault(t)
Expand Down
4 changes: 2 additions & 2 deletions fetch/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var benchDBDriverName = test.Mysql

func BenchmarkObject(b *testing.B) {
a := assert.New(b, false)
suite := test.NewSuite(a, benchDBDriverName)
suite := test.NewSuite(a, "", benchDBDriverName)

suite.Run(func(t *test.Driver) {
initDB(t)
Expand All @@ -43,7 +43,7 @@ func BenchmarkObject(b *testing.B) {

func BenchmarkMap(b *testing.B) {
a := assert.New(b, false)
suite := test.NewSuite(a, benchDBDriverName)
suite := test.NewSuite(a, "", benchDBDriverName)

suite.Run(func(t *test.Driver) {
initDB(t)
Expand Down
4 changes: 2 additions & 2 deletions fetch/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

func TestColumn(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

eq := func(s1, s2 []any) bool {
if len(s1) != len(s2) {
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestColumn(t *testing.T) {

func TestColumnString(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a)
suite := test.NewSuite(a, "")

suite.Run(func(t *test.Driver) {
initDB(t)
Expand Down
Loading

0 comments on commit b80608e

Please sign in to comment.