From e50cbd0fade3782d37c49cd24b1271f43354f189 Mon Sep 17 00:00:00 2001 From: caixw Date: Tue, 9 Apr 2024 22:14:47 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20Context=20?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- core/core.go | 8 +-- db.go | 58 +++++++++++++--- sqlbuilder.go | 45 +++++++------ sqlbuilder/insert.go | 2 +- sqlbuilder/select.go | 36 +++++++--- sqlbuilder/sqlbuilder_test.go | 4 +- sqlbuilder/table.go | 2 +- sqlbuilder/view.go | 2 +- tx.go | 123 ++++++++++++++++++++++++++++------ types.go | 22 ++++-- 10 files changed, 222 insertions(+), 80 deletions(-) diff --git a/core/core.go b/core/core.go index 0418968..232ba06 100644 --- a/core/core.go +++ b/core/core.go @@ -39,24 +39,20 @@ type ConstraintType int8 // - {} 符号会被替换为 [Dialect.Quotes] 对应的符号; // - # 会被替换为 [Engine.TablePrefix] 的返回值; type Engine interface { - Dialect() Dialect - Query(query string, args ...any) (*sql.Rows, error) - QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryRow(query string, args ...any) *sql.Row - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row Exec(query string, args ...any) (sql.Result, error) - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) Prepare(query string) (*Stmt, error) - PrepareContext(ctx context.Context, query string) (*Stmt, error) + Dialect() Dialect + // TablePrefix 所有数据表拥有的统一表名前缀 // // 当需要在一个数据库中创建不同的实例, diff --git a/db.go b/db.go index fffe927..f3136a3 100644 --- a/db.go +++ b/db.go @@ -92,35 +92,73 @@ func (db *DB) Close() error { // Version 数据库服务端的版本号 func (db *DB) Version() string { return db.version } -func (db *DB) LastInsertID(v TableNamer) (int64, error) { return lastInsertID(db, v) } +func (db *DB) LastInsertID(v TableNamer) (int64, error) { + return db.LastInsertIDContext(context.Background(), v) +} + +func (db *DB) LastInsertIDContext(ctx context.Context, v TableNamer) (int64, error) { + return lastInsertID(ctx, db, v) +} // Insert 插入数据 // // NOTE: 若需一次性插入多条数据,请使用 [Tx.InsertMany]。 -func (db *DB) Insert(v TableNamer) (sql.Result, error) { return insert(db, v) } +func (db *DB) Insert(v TableNamer) (sql.Result, error) { + return db.InsertContext(context.Background(), v) +} + +func (db *DB) InsertContext(ctx context.Context, v TableNamer) (sql.Result, error) { + return insert(ctx, db, v) +} -func (db *DB) Delete(v TableNamer) (sql.Result, error) { return del(db, v) } +func (db *DB) Delete(v TableNamer) (sql.Result, error) { + return db.DeleteContext(context.Background(), v) +} -func (db *DB) Update(v TableNamer, cols ...string) (sql.Result, error) { return update(db, v, cols...) } +func (db *DB) DeleteContext(ctx context.Context, v TableNamer) (sql.Result, error) { + return del(ctx, db, v) +} -func (db *DB) Select(v TableNamer) (bool, error) { return find(db, v) } +func (db *DB) Update(v TableNamer, cols ...string) (sql.Result, error) { + return db.UpdateContext(context.Background(), v, cols...) +} -func (db *DB) Create(v TableNamer) error { return create(db, v) } +func (db *DB) UpdateContext(ctx context.Context, v TableNamer, cols ...string) (sql.Result, error) { + return update(ctx, db, v, cols...) +} -func (db *DB) Drop(v TableNamer) error { return drop(db, v) } +func (db *DB) Select(v TableNamer) (bool, error) { return db.SelectContext(context.Background(), v) } + +func (db *DB) SelectContext(ctx context.Context, v TableNamer) (bool, error) { return find(ctx, db, v) } + +func (db *DB) Create(v TableNamer) error { return db.CreateContext(context.Background(), v) } + +func (db *DB) CreateContext(ctx context.Context, v TableNamer) error { return create(ctx, db, v) } + +func (db *DB) Drop(v TableNamer) error { return db.DropContext(context.Background(), v) } + +func (db *DB) DropContext(ctx context.Context, v TableNamer) error { return drop(ctx, db, v) } func (db *DB) Truncate(v TableNamer) error { + return db.TruncateContext(context.Background(), v) +} + +func (db *DB) TruncateContext(ctx context.Context, v TableNamer) error { if !db.Dialect().TransactionalDDL() { - return truncate(db, v) + return truncate(ctx, db, v) } - return db.DoTransaction(func(tx *Tx) error { return truncate(tx, v) }) + return db.DoTransaction(func(tx *Tx) error { return truncate(ctx, tx, v) }) } // InsertMany 一次插入多条数据 // // 会自动转换成事务进行处理。 func (db *DB) InsertMany(max int, v ...TableNamer) error { - return db.DoTransaction(func(tx *Tx) error { return tx.InsertMany(max, v...) }) + return db.InsertManyContext(context.Background(), max, v...) +} + +func (db *DB) InsertManyContext(ctx context.Context, max int, v ...TableNamer) error { + return db.DoTransaction(func(tx *Tx) error { return tx.InsertManyContext(ctx, max, v...) }) } func (db *DB) SQLBuilder() *sqlbuilder.SQLBuilder { return db.sqlBuilder } diff --git a/sqlbuilder.go b/sqlbuilder.go index 5181517..afd3a32 100644 --- a/sqlbuilder.go +++ b/sqlbuilder.go @@ -5,6 +5,7 @@ package orm import ( + "context" "database/sql" "errors" "fmt" @@ -100,14 +101,14 @@ func getKV(rval reflect.Value, cols ...*core.Column) (keys []string, vals []any) } // 创建表或是视图 -func create(e Engine, v TableNamer) error { +func create(ctx context.Context, e Engine, v TableNamer) error { m, _, err := getModel(e, v) if err != nil { return err } if m.Type == core.View { - return createView(e, m) + return createView(ctx, e, m) } sb := e.SQLBuilder().CreateTable().Table(m.Name) @@ -152,19 +153,19 @@ func create(e Engine, v TableNamer) error { sb.PK(constraintName(m.Name, m.PrimaryKey.Name), cols...) } - return sb.Exec() + return sb.ExecContext(ctx) } -func createView(e Engine, m *core.Model) error { +func createView(ctx context.Context, e Engine, m *core.Model) error { stmt := e.SQLBuilder().CreateView().Name(m.Name) for _, col := range m.Columns { stmt.Column(col.Name) } - return stmt.FromQuery(m.ViewAs).Exec() + return stmt.FromQuery(m.ViewAs).ExecContext(ctx) } -func truncate(e Engine, v TableNamer) error { +func truncate(ctx context.Context, e Engine, v TableNamer) error { m, err := e.newModel(v) if err != nil { return err @@ -181,24 +182,24 @@ func truncate(e Engine, v TableNamer) error { stmt.Table(m.Name, "") } - return stmt.Exec() + return stmt.ExecContext(ctx) } // 删除表或视图 -func drop(e Engine, v TableNamer) error { +func drop(ctx context.Context, e Engine, v TableNamer) error { m, err := e.newModel(v) if err != nil { return err } if m.Type == core.View { - return e.SQLBuilder().DropView().Name(m.Name).Exec() + return e.SQLBuilder().DropView().Name(m.Name).ExecContext(ctx) } - return e.SQLBuilder().DropTable().Table(m.Name).Exec() + return e.SQLBuilder().DropTable().Table(m.Name).ExecContext(ctx) } -func lastInsertID(e Engine, v TableNamer) (int64, error) { +func lastInsertID(ctx context.Context, e Engine, v TableNamer) (int64, error) { m, rval, err := getModel(e, v) if err != nil { return 0, err @@ -233,10 +234,10 @@ func lastInsertID(e Engine, v TableNamer) (int64, error) { stmt.KeyValue(col.Name, field.Interface()) } - return stmt.LastInsertID(m.Name, m.AutoIncrement.Name) + return stmt.LastInsertIDContext(ctx, m.AutoIncrement.Name) } -func insert(e Engine, v TableNamer) (sql.Result, error) { +func insert(ctx context.Context, e Engine, v TableNamer) (sql.Result, error) { m, rval, err := getModel(e, v) if err != nil { return nil, err @@ -267,14 +268,14 @@ func insert(e Engine, v TableNamer) (sql.Result, error) { stmt.KeyValue(col.Name, field.Interface()) } - return stmt.Exec() + return stmt.ExecContext(ctx) } // 查找数据 // // 根据 v 的 pk 或中唯一索引列查找一行数据,并赋值给 v。 // 若 v 为空,则不发生任何操作,v 可以是数组。 -func find(e Engine, v TableNamer) (bool, error) { +func find(ctx context.Context, e Engine, v TableNamer) (bool, error) { m, rval, err := getModel(e, v) if err != nil { return false, err @@ -285,7 +286,7 @@ func find(e Engine, v TableNamer) (bool, error) { return false, err } - size, err := stmt.QueryObject(true, v) + size, err := stmt.QueryObjectContext(ctx, true, v) if err != nil { return false, err } @@ -293,7 +294,7 @@ func find(e Engine, v TableNamer) (bool, error) { } // for update 只能作用于事务 -func forUpdate(tx *Tx, v TableNamer) error { +func forUpdate(ctx context.Context, tx *Tx, v TableNamer) error { m, rval, err := getModel(tx, v) if err != nil { return err @@ -314,7 +315,7 @@ func forUpdate(tx *Tx, v TableNamer) error { return err } - _, err = stmt.QueryObject(true, v) + _, err = stmt.QueryObjectContext(ctx, true, v) return err } @@ -323,7 +324,7 @@ func forUpdate(tx *Tx, v TableNamer) error { // // 更新依据为每个对象的主键或是唯一索引列。 // 若不存在此两个类型的字段,则返回错误信息。 -func update(e Engine, v TableNamer, cols ...string) (sql.Result, error) { +func update(ctx context.Context, e Engine, v TableNamer, cols ...string) (sql.Result, error) { stmt := e.SQLBuilder().Update() m, rval, err := getUpdateColumns(e, v, stmt, cols...) @@ -335,7 +336,7 @@ func update(e Engine, v TableNamer, cols ...string) (sql.Result, error) { return nil, err } - return stmt.Exec() + return stmt.ExecContext(ctx) } func getUpdateColumns(e Engine, v TableNamer, stmt *sqlbuilder.UpdateStmt, cols ...string) (*core.Model, reflect.Value, error) { @@ -378,7 +379,7 @@ func getUpdateColumns(e Engine, v TableNamer, stmt *sqlbuilder.UpdateStmt, cols } // 将 v 生成 delete 的 sql 语句 -func del(e Engine, v TableNamer) (sql.Result, error) { +func del(ctx context.Context, e Engine, v TableNamer) (sql.Result, error) { m, rval, err := getModel(e, v) if err != nil { return nil, err @@ -393,7 +394,7 @@ func del(e Engine, v TableNamer) (sql.Result, error) { return nil, err } - return stmt.Exec() + return stmt.ExecContext(ctx) } var errInsertManyHasDifferentType = errors.New("InsertMany 必须是相同的数据类型") diff --git a/sqlbuilder/insert.go b/sqlbuilder/insert.go index 55522ce..213a5f8 100644 --- a/sqlbuilder/insert.go +++ b/sqlbuilder/insert.go @@ -206,7 +206,7 @@ func (stmt *InsertStmt) fromSelect(builder *core.Builder) (string, []any, error) // 并根据表名和自增列 ID 返回当前行的自增 ID 值。 // // NOTE: 对于指定了自增值的,其结果是未知的。 -func (stmt *InsertStmt) LastInsertID(table, col string) (int64, error) { +func (stmt *InsertStmt) LastInsertID(col string) (int64, error) { return stmt.LastInsertIDContext(context.Background(), col) } diff --git a/sqlbuilder/select.go b/sqlbuilder/select.go index 0f30c39..bbe03fa 100644 --- a/sqlbuilder/select.go +++ b/sqlbuilder/select.go @@ -437,25 +437,37 @@ func (stmt *SelectStmt) Union(all bool, sel ...*SelectStmt) *SelectStmt { // // 关于 objs 的类型,可以参考 [fetch.Object] 函数的相关介绍。 func (stmt *SelectStmt) QueryObject(strict bool, objs any) (size int, err error) { - rows, err := stmt.Query() + return stmt.QueryObjectContext(context.Background(), strict, objs) +} + +func (stmt *SelectStmt) QueryObjectContext(ctx context.Context, strict bool, objs any) (size int, err error) { + rows, err := stmt.QueryContext(ctx) if err != nil { return 0, err } - return queryObject(rows, strict, objs) + return fetchObject(rows, strict, objs) } // QueryString 查询指定列的第一行数据,并将其转换成 string func (stmt *SelectStmt) QueryString(colName string) (v string, err error) { - rows, err := stmt.Query() + return stmt.QueryStringContext(context.Background(), colName) +} + +func (stmt *SelectStmt) QueryStringContext(ctx context.Context, colName string) (v string, err error) { + rows, err := stmt.QueryContext(ctx) if err != nil { return "", err } - return queryString(rows, colName) + return fetchString(rows, colName) } // QueryFloat 查询指定列的第一行数据,并将其转换成 float64 func (stmt *SelectStmt) QueryFloat(colName string) (float64, error) { - v, err := stmt.QueryString(colName) + return stmt.QueryFloatContext(context.Background(), colName) +} + +func (stmt *SelectStmt) QueryFloatContext(ctx context.Context, colName string) (float64, error) { + v, err := stmt.QueryStringContext(ctx, colName) if err != nil { return 0, err } @@ -465,10 +477,14 @@ func (stmt *SelectStmt) QueryFloat(colName string) (float64, error) { // QueryInt 查询指定列的第一行数据,并将其转换成 int64 func (stmt *SelectStmt) QueryInt(colName string) (int64, error) { + return stmt.QueryIntContext(context.Background(), colName) +} + +func (stmt *SelectStmt) QueryIntContext(ctx context.Context, colName string) (int64, error) { // NOTE: 可能会出现浮点数的情况。比如: // select avg(xx) as avg form xxx where xxx // 查询 avg 的值可能是 5.000 等值。 - v, err := stmt.QueryString(colName) + v, err := stmt.QueryStringContext(ctx, colName) if err != nil { return 0, err } @@ -503,7 +519,7 @@ func (stmt *SelectQuery) QueryObject(strict bool, objs any, arg ...any) (size in if err != nil { return 0, err } - return queryObject(rows, strict, objs) + return fetchObject(rows, strict, objs) } // QueryString 查询指定列的第一行数据,并将其转换成 string @@ -512,7 +528,7 @@ func (stmt *SelectQuery) QueryString(colName string, arg ...any) (v string, err if err != nil { return "", err } - return queryString(rows, colName) + return fetchString(rows, colName) } // QueryFloat 查询指定列的第一行数据,并将其转换成 float64 @@ -540,13 +556,13 @@ func (stmt *SelectQuery) QueryInt(colName string, arg ...any) (int64, error) { func (stmt *SelectQuery) Close() error { return stmt.stmt.Close() } -func queryObject(rows *sql.Rows, strict bool, objs any) (size int, err error) { +func fetchObject(rows *sql.Rows, strict bool, objs any) (size int, err error) { defer func() { err = errors.Join(err, rows.Close()) }() size, err = fetch.Object(strict, rows, objs) return } -func queryString(rows *sql.Rows, colName string) (v string, err error) { +func fetchString(rows *sql.Rows, colName string) (v string, err error) { defer func() { err = errors.Join(err, rows.Close()) }() cols, err := fetch.ColumnString(true, colName, rows) diff --git a/sqlbuilder/sqlbuilder_test.go b/sqlbuilder/sqlbuilder_test.go index 0881074..9c70fcc 100644 --- a/sqlbuilder/sqlbuilder_test.go +++ b/sqlbuilder/sqlbuilder_test.go @@ -75,7 +75,7 @@ func initDB(t *test.Driver) { sql.Table("users"). Columns("name"). Values("7") - id, err := sql.LastInsertID("users", "id") + id, err := sql.LastInsertID("id") t.NotError(err, "%s@%s", err, t.DriverName). Equal(id, 7, "%d != %d @ %s", id, 7, t.DriverName) @@ -84,7 +84,7 @@ func initDB(t *test.Driver) { Columns("name"). Values("8"). Values("9") - id, err = sql.LastInsertID("users", "id") + id, err = sql.LastInsertID("id") t.Error(err, "%s@%s", err, t.DriverName). Empty(id, "not empty @%s", t.DriverName) } diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index c593f66..78eb702 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -509,7 +509,7 @@ func (stmt *TableExistsStmt) Exists() (bool, error) { return false, err } - name, err := queryString(rows, "name") + name, err := fetchString(rows, "name") switch { case errors.Is(err, ErrNoData): return false, nil diff --git a/sqlbuilder/view.go b/sqlbuilder/view.go index a32da8b..942015a 100644 --- a/sqlbuilder/view.go +++ b/sqlbuilder/view.go @@ -195,7 +195,7 @@ func (stmt *ViewExistsStmt) Exists() (bool, error) { return false, err } - name, err := queryString(rows, "name") + name, err := fetchString(rows, "name") switch { case errors.Is(err, ErrNoData): return false, nil diff --git a/tx.go b/tx.go index a046aba..61c12bb 100644 --- a/tx.go +++ b/tx.go @@ -42,26 +42,65 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { }, nil } -func (tx *Tx) LastInsertID(v TableNamer) (int64, error) { return lastInsertID(tx, v) } +func (tx *Tx) LastInsertID(v TableNamer) (int64, error) { + return tx.LastInsertIDContext(context.Background(), v) +} +func (tx *Tx) LastInsertIDContext(ctx context.Context, v TableNamer) (int64, error) { + return lastInsertID(ctx, tx, v) +} + +func (tx *Tx) Insert(v TableNamer) (sql.Result, error) { + return tx.InsertContext(context.Background(), v) +} + +func (tx *Tx) InsertContext(ctx context.Context, v TableNamer) (sql.Result, error) { + return insert(ctx, tx, v) +} -func (tx *Tx) Insert(v TableNamer) (sql.Result, error) { return insert(tx, v) } +func (tx *Tx) Select(v TableNamer) (bool, error) { return tx.SelectContext(context.Background(), v) } -func (tx *Tx) Select(v TableNamer) (bool, error) { return find(tx, v) } +func (tx *Tx) SelectContext(ctx context.Context, v TableNamer) (bool, error) { return find(ctx, tx, v) } // ForUpdate 读数据并锁定 -func (tx *Tx) ForUpdate(v TableNamer) error { return forUpdate(tx, v) } +func (tx *Tx) ForUpdate(v TableNamer) error { return tx.ForUpdateContext(context.Background(), v) } + +func (tx *Tx) ForUpdateContext(ctx context.Context, v TableNamer) error { return forUpdate(ctx, tx, v) } + +func (tx *Tx) InsertMany(max int, v ...TableNamer) error { + return tx.InsertManyContext(context.Background(), max, v...) +} + +func (tx *Tx) InsertManyContext(ctx context.Context, max int, v ...TableNamer) error { + return txInsertMany(ctx, tx, max, v...) +} + +func (tx *Tx) Update(v TableNamer, cols ...string) (sql.Result, error) { + return tx.UpdateContext(context.Background(), v, cols...) +} + +func (tx *Tx) UpdateContext(ctx context.Context, v TableNamer, cols ...string) (sql.Result, error) { + return update(ctx, tx, v, cols...) +} -func (tx *Tx) InsertMany(max int, v ...TableNamer) error { return txInsertMany(tx, max, v...) } +func (tx *Tx) Delete(v TableNamer) (sql.Result, error) { + return tx.DeleteContext(context.Background(), v) +} + +func (tx *Tx) DeleteContext(ctx context.Context, v TableNamer) (sql.Result, error) { + return del(ctx, tx, v) +} -func (tx *Tx) Update(v TableNamer, cols ...string) (sql.Result, error) { return update(tx, v, cols...) } +func (tx *Tx) Create(v TableNamer) error { return tx.CreateContext(context.Background(), v) } -func (tx *Tx) Delete(v TableNamer) (sql.Result, error) { return del(tx, v) } +func (tx *Tx) CreateContext(ctx context.Context, v TableNamer) error { return create(ctx, tx, v) } -func (tx *Tx) Create(v TableNamer) error { return create(tx, v) } +func (tx *Tx) Drop(v TableNamer) error { return tx.DropContext(context.Background(), v) } -func (tx *Tx) Drop(v TableNamer) error { return drop(tx, v) } +func (tx *Tx) DropContext(ctx context.Context, v TableNamer) error { return drop(ctx, tx, v) } -func (tx *Tx) Truncate(v TableNamer) error { return truncate(tx, v) } +func (tx *Tx) Truncate(v TableNamer) error { return tx.TruncateContext(context.Background(), v) } + +func (tx *Tx) TruncateContext(ctx context.Context, v TableNamer) error { return truncate(ctx, tx, v) } func (tx *Tx) SQLBuilder() *sqlbuilder.SQLBuilder { return sqlbuilder.New(tx) // 事务一般是一个临时对象,没必要像 [DB] 一样固定 sqlbuilder 对象。 @@ -89,31 +128,73 @@ func (tx *Tx) NewEngine(tablePrefix string) Engine { } } -func (p *txEngine) LastInsertID(v TableNamer) (int64, error) { return lastInsertID(p, v) } +func (p *txEngine) LastInsertID(v TableNamer) (int64, error) { + return p.LastInsertIDContext(context.Background(), v) +} + +func (p *txEngine) LastInsertIDContext(ctx context.Context, v TableNamer) (int64, error) { + return lastInsertID(ctx, p, v) +} + +func (p *txEngine) Insert(v TableNamer) (sql.Result, error) { + return p.InsertContext(context.Background(), v) +} + +func (p *txEngine) InsertContext(ctx context.Context, v TableNamer) (sql.Result, error) { + return insert(ctx, p, v) +} -func (p *txEngine) Insert(v TableNamer) (sql.Result, error) { return insert(p, v) } +func (p *txEngine) Delete(v TableNamer) (sql.Result, error) { + return p.DeleteContext(context.Background(), v) +} -func (p *txEngine) Delete(v TableNamer) (sql.Result, error) { return del(p, v) } +func (p *txEngine) DeleteContext(ctx context.Context, v TableNamer) (sql.Result, error) { + return del(ctx, p, v) +} func (p *txEngine) Update(v TableNamer, cols ...string) (sql.Result, error) { - return update(p, v, cols...) + return p.UpdateContext(context.Background(), v, cols...) +} + +func (p *txEngine) UpdateContext(ctx context.Context, v TableNamer, cols ...string) (sql.Result, error) { + return update(ctx, p, v, cols...) } -func (p *txEngine) Select(v TableNamer) (bool, error) { return find(p, v) } +func (p *txEngine) Select(v TableNamer) (bool, error) { + return p.SelectContext(context.Background(), v) +} + +func (p *txEngine) SelectContext(ctx context.Context, v TableNamer) (bool, error) { + return find(ctx, p, v) +} + +func (p *txEngine) Create(v TableNamer) error { return p.CreateContext(context.Background(), v) } + +func (p *txEngine) CreateContext(ctx context.Context, v TableNamer) error { return create(ctx, p, v) } -func (p *txEngine) Create(v TableNamer) error { return create(p, v) } +func (p *txEngine) Drop(v TableNamer) error { return p.DropContext(context.Background(), v) } -func (p *txEngine) Drop(v TableNamer) error { return drop(p, v) } +func (p *txEngine) DropContext(ctx context.Context, v TableNamer) error { return drop(ctx, p, v) } -func (p *txEngine) Truncate(v TableNamer) error { return truncate(p, v) } +func (p *txEngine) Truncate(v TableNamer) error { return p.TruncateContext(context.Background(), v) } -func (p *txEngine) InsertMany(max int, v ...TableNamer) error { return txInsertMany(p, max, v...) } +func (p *txEngine) TruncateContext(ctx context.Context, v TableNamer) error { + return truncate(ctx, p, v) +} + +func (p *txEngine) InsertMany(max int, v ...TableNamer) error { + return p.InsertManyContext(context.Background(), max, v...) +} + +func (p *txEngine) InsertManyContext(ctx context.Context, max int, v ...TableNamer) error { + return txInsertMany(ctx, p, max, v...) +} func (p *txEngine) SQLBuilder() *sqlbuilder.SQLBuilder { return sqlbuilder.New(p) // txPrefix 般是一个临时对象,没必要像 [DB] 一样固定 sqlbuilder 对象。 } -func txInsertMany(tx Engine, max int, v ...TableNamer) error { +func txInsertMany(ctx context.Context, tx Engine, max int, v ...TableNamer) error { l := len(v) for i := 0; i < l; i += max { j := min(i+max, l) @@ -122,7 +203,7 @@ func txInsertMany(tx Engine, max int, v ...TableNamer) error { return err } - if _, err = query.Exec(); err != nil { + if _, err = query.ExecContext(ctx); err != nil { return err } } diff --git a/types.go b/types.go index 21ad7ca..661bddd 100644 --- a/types.go +++ b/types.go @@ -5,6 +5,7 @@ package orm import ( + "context" "database/sql" "time" @@ -51,7 +52,7 @@ type ( Engine interface { core.Engine - // LastInsertID 插入一条数据并返回其自增 ID + // LastInsertIDContext 插入一条数据并返回其自增 ID // // 理论上功能等同于以下两步操作: // rslt, err := engine.Insert(obj) @@ -61,28 +62,32 @@ type ( // 更简单和安全的方法。 // // NOTE: 要求 v 有定义自增列。 + LastInsertIDContext(ctx context.Context, v TableNamer) (int64, error) LastInsertID(v TableNamer) (int64, error) - // Insert 插入数据 + // InsertContext 插入数据 // // NOTE: 若需一次性插入多条数据,请使用 [Engine.InsertMany] 。 + InsertContext(ctx context.Context, v TableNamer) (sql.Result, error) Insert(v TableNamer) (sql.Result, error) // Delete 删除符合条件的数据 // // 查找条件以结构体定义的主键或是唯一约束(在没有主键的情况下)来查找, // 若两者都不存在,则将返回 error + DeleteContext(ctx context.Context, v TableNamer) (sql.Result, error) Delete(v TableNamer) (sql.Result, error) - // Update 更新数据 + // UpdateContext 更新数据 // // 零值不会被提交,cols 指定的列,即使是零值也会被更新。 // // 查找条件以结构体定义的主键或是唯一约束(在没有主键的情况下)来查找, // 若两者都不存在,则将返回 error + UpdateContext(ctx context.Context, v TableNamer, cols ...string) (sql.Result, error) Update(v TableNamer, cols ...string) (sql.Result, error) - // Select 查询一个符合条件的数据 + // SelectContext 查询一个符合条件的数据 // // 查找条件以结构体定义的主键或是唯一约束(在没有主键的情况下 ) 来查找, // 若两者都不存在,则将返回 error @@ -90,21 +95,26 @@ type ( // // 查找条件的查找顺序是为 自增 > 主键 > 唯一约束, // 如果同时存在多个唯一约束满足条件(可能每个唯一约束查询至的结果是不一样的),则返回错误信息。 + SelectContext(ctx context.Context, v TableNamer) (found bool, err error) Select(v TableNamer) (found bool, err error) + CreateContext(ctx context.Context, v TableNamer) error Create(v TableNamer) error + DropContext(ctx context.Context, v TableNamer) error Drop(v TableNamer) error - // Truncate 清空表并重置 ai 但保留表结构 + // TruncateContext 清空表并重置 ai 但保留表结构 + TruncateContext(ctx context.Context, v TableNamer) error Truncate(v TableNamer) error - // InsertMany 插入多条相同的数据 + // InsertManyContext 插入多条相同的数据 // // 若需要向某张表中插入多条记录,此方法会比 [Engine.Insert] 性能上好很多。 // // max 表示一次最多插入的数量,如果超过此值,会分批执行, // 但是依然在一个事务中完成。 + InsertManyContext(ctx context.Context, max int, v ...TableNamer) error InsertMany(max int, v ...TableNamer) error // Where 生成 [WhereStmt] 语句