From 128a098378cf58aabbe15ae149a6e2435647ea06 Mon Sep 17 00:00:00 2001 From: caixw Date: Tue, 9 Apr 2024 15:59:27 +0800 Subject: [PATCH] =?UTF-8?q?refactor(internal/model):=20=E5=B0=86=20interna?= =?UTF-8?q?l/engine=20=E5=B9=B6=E5=85=A5=20internal/model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- db.go | 8 ++++---- internal/engine/engine_test.go | 15 --------------- internal/model/bench_test.go | 4 ++-- internal/model/column_test.go | 2 +- internal/{engine => model}/engine.go | 14 +++++++------- internal/model/model.go | 4 ++-- internal/model/model_test.go | 2 +- internal/model/models.go | 16 +++++++++++----- internal/model/models_test.go | 2 +- tx.go | 5 ++--- 10 files changed, 31 insertions(+), 41 deletions(-) delete mode 100644 internal/engine/engine_test.go rename internal/{engine => model}/engine.go (91%) diff --git a/db.go b/db.go index c3f586a..fffe927 100644 --- a/db.go +++ b/db.go @@ -9,7 +9,6 @@ import ( "database/sql" "github.com/issue9/orm/v6/core" - "github.com/issue9/orm/v6/internal/engine" "github.com/issue9/orm/v6/internal/model" "github.com/issue9/orm/v6/sqlbuilder" ) @@ -44,9 +43,11 @@ func NewDB(tablePrefix, dsn string, dialect Dialect) (*DB, error) { // NOTE: 请确保用于打开 db 的 driverName 参数与 dialect.DriverName() 是相同的, // 否则后续操作的结果是未知的。 func NewDBWithStdDB(tablePrefix string, db *sql.DB, dialect Dialect) (*DB, error) { + ms := model.NewModels(dialect) inst := &DB{ db: db, - Engine: engine.New(db, tablePrefix, dialect), + models: ms, + Engine: ms.NewEngine(db, tablePrefix), } ver, err := sqlbuilder.Version(inst) @@ -55,7 +56,6 @@ func NewDBWithStdDB(tablePrefix string, db *sql.DB, dialect Dialect) (*DB, error } inst.version = ver - inst.models = model.NewModels() inst.sqlBuilder = sqlbuilder.New(inst) return inst, nil @@ -69,7 +69,7 @@ func (db *DB) New(tablePrefix string) *DB { return db } - e := engine.New(db.DB(), tablePrefix, db.Dialect()) + e := db.models.NewEngine(db.DB(), tablePrefix) return &DB{ Engine: e, sqlBuilder: sqlbuilder.New(e), diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go deleted file mode 100644 index 1c19f97..0000000 --- a/internal/engine/engine_test.go +++ /dev/null @@ -1,15 +0,0 @@ -// SPDX-FileCopyrightText: 2024 caixw -// -// SPDX-License-Identifier: MIT - -package engine_test - -import ( - "testing" - - "github.com/issue9/orm/v6/internal/test" -) - -func TestMain(m *testing.M) { - test.Main(m) -} diff --git a/internal/model/bench_test.go b/internal/model/bench_test.go index 6d07305..783b335 100644 --- a/internal/model/bench_test.go +++ b/internal/model/bench_test.go @@ -12,7 +12,7 @@ import ( func BenchmarkNewModelNoCached(b *testing.B) { a := assert.New(b, false) - ms := NewModels() + ms := NewModels(nil) a.NotNil(ms) for i := 0; i < b.N; i++ { @@ -24,7 +24,7 @@ func BenchmarkNewModelNoCached(b *testing.B) { func BenchmarkNewModelCached(b *testing.B) { a := assert.New(b, false) - ms := NewModels() + ms := NewModels(nil) a.NotNil(ms) for i := 0; i < b.N; i++ { diff --git a/internal/model/column_test.go b/internal/model/column_test.go index 0e56ff0..e5ae2dc 100644 --- a/internal/model/column_test.go +++ b/internal/model/column_test.go @@ -93,7 +93,7 @@ func TestColumn_setNullable(t *testing.T) { a.Error(col.setNullable([]string{"1", "2"})) a.Error(col.setNullable([]string{"T1"})) - ms := NewModels() + ms := NewModels(nil) a.NotNil(ms) // 将 AI 设置为 nullable diff --git a/internal/engine/engine.go b/internal/model/engine.go similarity index 91% rename from internal/engine/engine.go rename to internal/model/engine.go index a5dc27d..b85505e 100644 --- a/internal/engine/engine.go +++ b/internal/model/engine.go @@ -2,8 +2,7 @@ // // SPDX-License-Identifier: MIT -// Package engine [core.Engine] 的默认实现 -package engine +package model import ( "context" @@ -14,8 +13,8 @@ import ( ) type coreEngine struct { + ms *Models engine stdEngine - dialect core.Dialect tablePrefix string replacer *strings.Replacer sqlLogger func(string) @@ -31,12 +30,13 @@ type stdEngine interface { func defaultSQLLogger(string) {} -func New(e stdEngine, tablePrefix string, d core.Dialect) core.Engine { - l, r := d.Quotes() +// NewEngine 声明实现 [core.Engine] 接口的实例 +func (ms *Models) NewEngine(e stdEngine, tablePrefix string) core.Engine { + l, r := ms.dialect.Quotes() return &coreEngine{ + ms: ms, engine: e, - dialect: d, tablePrefix: tablePrefix, sqlLogger: defaultSQLLogger, replacer: strings.NewReplacer( @@ -60,7 +60,7 @@ func (db *coreEngine) Debug(l func(string)) { db.sqlLogger = l } -func (db *coreEngine) Dialect() core.Dialect { return db.dialect } +func (db *coreEngine) Dialect() core.Dialect { return db.ms.dialect } func (db *coreEngine) QueryRow(query string, args ...any) *sql.Row { return db.QueryRowContext(context.Background(), query, args...) diff --git a/internal/model/model.go b/internal/model/model.go index f5d9c14..fdae560 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -19,9 +19,9 @@ func propertyError(field, name, message string) error { return fmt.Errorf("%s 的 %s 属性发生以下错误: %s", field, name, message) } -// New 从一个 obj 声明一个 Model 实例 +// New 从一个 obj 声明 [core.Model] 实例 // -// obj 可以是一个 struct 实例或是指针。 +// obj 可以是一个结构体或是指针。 func (ms *Models) New(obj core.TableNamer) (*core.Model, error) { rtype := reflect.TypeOf(obj) for rtype.Kind() == reflect.Ptr { diff --git a/internal/model/model_test.go b/internal/model/model_test.go index fcabcc9..058fd7b 100644 --- a/internal/model/model_test.go +++ b/internal/model/model_test.go @@ -108,7 +108,7 @@ func (v *viewObject) ViewAs() (string, error) { func TestModels_New(t *testing.T) { a := assert.New(t, false) - ms := NewModels() + ms := NewModels(nil) a.NotNil(ms) m, err := ms.New(&Admin{}) diff --git a/internal/model/models.go b/internal/model/models.go index 84438c2..42b0208 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -4,21 +4,27 @@ package model -import "sync" +import ( + "sync" + + "github.com/issue9/orm/v6/core" +) // Models 数据模型管理 type Models struct { - models *sync.Map + dialect core.Dialect + models *sync.Map } // NewModels 声明 [Models] 变量 -func NewModels() *Models { +func NewModels(d core.Dialect) *Models { return &Models{ - models: &sync.Map{}, + dialect: d, + models: &sync.Map{}, } } -// Clear 清除所有的 Model 缓存 +// Clear 清除所有的 [core.Model] 缓存 func (ms *Models) Clear() { ms.models.Range(func(key, _ any) bool { ms.models.Delete(key) diff --git a/internal/model/models_test.go b/internal/model/models_test.go index 6cf5a41..ac59c50 100644 --- a/internal/model/models_test.go +++ b/internal/model/models_test.go @@ -23,7 +23,7 @@ func (ms *Models) len() (cnt int) { func TestModels(t *testing.T) { a := assert.New(t, false) - ms := NewModels() + ms := NewModels(nil) a.NotNil(ms) m, err := ms.New(&User{}) diff --git a/tx.go b/tx.go index c101d40..a046aba 100644 --- a/tx.go +++ b/tx.go @@ -10,7 +10,6 @@ import ( "errors" "github.com/issue9/orm/v6/core" - "github.com/issue9/orm/v6/internal/engine" "github.com/issue9/orm/v6/sqlbuilder" ) @@ -39,7 +38,7 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { return &Tx{ tx: tx, db: db, - Engine: engine.New(tx, db.TablePrefix(), db.Dialect()), + Engine: db.models.NewEngine(tx, db.TablePrefix()), }, nil } @@ -85,7 +84,7 @@ func (tx *Tx) NewEngine(tablePrefix string) Engine { } return &txEngine{ - Engine: engine.New(tx.Tx(), tablePrefix, tx.Dialect()), + Engine: tx.db.models.NewEngine(tx.Tx(), tablePrefix), tx: tx, } }