From ef2e659db2ed6b498aa36b6cdd04b55002766468 Mon Sep 17 00:00:00 2001 From: tanyinloo Date: Fri, 1 Mar 2024 16:10:46 +0800 Subject: [PATCH] implement TxWithLock --- coredb/tx.go | 90 ++++++++++++++++++++++++++++++++++++++-------- tests/tx_test.go | 90 ++++++++++++++++++++++++++++++++++------------ tests/user_test.go | 19 +++++----- 3 files changed, 152 insertions(+), 47 deletions(-) diff --git a/coredb/tx.go b/coredb/tx.go index 6ade19d..460a69b 100644 --- a/coredb/tx.go +++ b/coredb/tx.go @@ -11,8 +11,8 @@ import ( // Make sure you call Commit or Rollback on the returned Tx. // Refer to https://go.dev/doc/database/execute-transactions on how to use the returned Tx. func BeginTx(ctx context.Context, dbname string, opts *sql.TxOptions) (tx *sql.Tx, err error) { - mydb := getDB(dbname, DBModeWrite) - return mydb.BeginTx(ctx, opts) + myDB := getDB(dbname, DBModeWrite) + return myDB.BeginTx(ctx, opts) } // DefaultTxOpts is package variable with default transaction level @@ -21,6 +21,14 @@ var DefaultTxOpts = sql.TxOptions{ ReadOnly: false, } +func newLockError(lock string, durationInSec int) error { + return fmt.Errorf("fail to acquire lock: %s, durationInSec: %d", lock, durationInSec) +} + +func newReleaseLockError(lock string, durationInSec int) error { + return fmt.Errorf("fail to release lock: %s, durationInSec: %d", lock, durationInSec) +} + // TxContext interface for DAO operations with context. type TxContext interface { // Exec executes a query without returning any rows. @@ -67,7 +75,9 @@ func (t *tx) Query(results any, query string, params ...any) error { if err != nil { return err } - defer rows.Close() + defer func(rows *sql.Rows) { + err = rows.Close() + }(rows) return RowsToStructSliceReflect(rows, results) } @@ -86,7 +96,7 @@ func (t *tx) FindOne(result any, tableName string, whereSQL string, params ...an if err2 != nil { // It's on purpose the hide the error // But should re-consider later - if err2 == sql.ErrNoRows { + if errors.Is(err2, sql.ErrNoRows) { return nil } return err2 @@ -112,27 +122,39 @@ func (t *tx) Rollback() error { return t.Tx.Rollback() } -// Connector for sql database. -type Connector interface { +// TxStarter for sql database. +type TxStarter interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } +type ConnectionGetter interface { + Conn(ctx context.Context) (*sql.Conn, error) +} + +type TxStarterWithConnection interface { + TxStarter + ConnectionGetter +} + // TxProvider ... type TxProvider struct { - conn Connector + conn TxStarterWithConnection } // NewTxProvider ... func NewTxProvider(dbname string) *TxProvider { - mydb := getDB(dbname, DBModeWrite) + myDB := getDB(dbname, DBModeWrite) return &TxProvider{ - conn: mydb, + conn: myDB, } } // acquireWithOpts transaction from db -func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) (*tx, error) { - trx, err := t.conn.BeginTx(ctx, opts) +func (t *TxProvider) acquireWithOpts(ctx context.Context, conn TxStarter, opts *sql.TxOptions) (*tx, error) { + if conn == nil { + conn = t.conn + } + trx, err := conn.BeginTx(ctx, opts) if err != nil { return nil, err } @@ -144,9 +166,9 @@ func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) ( } // TxWithOpts ... -func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) (err error) { +func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, conn TxStarter, opts *sql.TxOptions) (err error) { var trx *tx - trx, err = t.acquireWithOpts(ctx, opts) + trx, err = t.acquireWithOpts(ctx, conn, opts) if err != nil { return err } @@ -180,5 +202,45 @@ func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, o // Tx runs fn in transaction. func (t *TxProvider) Tx(ctx context.Context, fn func(TxContext) error) error { - return t.TxWithOpts(ctx, fn, &DefaultTxOpts) + return t.TxWithOpts(ctx, fn, nil, &DefaultTxOpts) +} + +func (t *TxProvider) TxWithLock(ctx context.Context, lock string, durationInSec int, fn func(txContext TxContext) error) error { + dbConn, err := t.conn.Conn(ctx) + if err != nil { + return fmt.Errorf("fail to get db connection: %w", err) + } + + { + var res int + err = dbConn.QueryRowContext(ctx, "select get_lock(?,?)", lock, durationInSec).Scan(&res) + if err != nil { + return fmt.Errorf("get_lock failed: %w", err) + } + if res != 1 { + return newLockError(lock, durationInSec) + } + } + + defer func() { + var res int + errRelease := dbConn.QueryRowContext(ctx, "select release_lock(?)", lock).Scan(&res) + if errRelease != nil { + if err == nil { + err = fmt.Errorf("release_lock failed: %w", errRelease) + } else { + err = errors.Join(err, fmt.Errorf("release_lock failed: %w", errRelease)) + } + return + } + if res != 1 { + if err == nil { + err = newReleaseLockError(lock, durationInSec) + } else { + err = errors.Join(err, newReleaseLockError(lock, durationInSec)) + } + } + }() + + return t.TxWithOpts(ctx, fn, dbConn, &DefaultTxOpts) } diff --git a/tests/tx_test.go b/tests/tx_test.go index 68db8c4..231eadd 100644 --- a/tests/tx_test.go +++ b/tests/tx_test.go @@ -2,11 +2,12 @@ package tests import ( "context" - "database/sql" "errors" "fmt" + "log" "reflect" - "strings" + "sync" + "testing" "time" _ "github.com/go-sql-driver/mysql" @@ -15,9 +16,52 @@ import ( "github.com/olachat/gola/v2/golalib/testdata/worker" ) -func ExampleNewTxProvider() { +func TestTxWithLock(t *testing.T) { + prov := coredb.NewTxProvider("testdb") + ctx := context.Background() + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + wg := &sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + log.Println("1: start lock") + err1 := prov.TxWithLock(ctx, "lock", 2, func(tx coredb.TxContext) error { + log.Println("1: locked") + time.Sleep(1800 * time.Millisecond) + log.Println("1: start unlock") + return nil + }) + if err1 != nil { + log.Printf("1: error: %v", err1) + } + log.Println("1: unlocked") + }() + + time.Sleep(10 * time.Millisecond) + wg.Add(1) + go func() { + defer wg.Done() + log.Println("2: start lock") + err2 := prov.TxWithLock(ctx, "lock", 1, func(tx coredb.TxContext) error { + log.Println("2: locked") + time.Sleep(800 * time.Millisecond) + log.Println("2: start unlock") + return nil + }) + if err2 != nil { + log.Printf("2: error: %v", err2) + } else { + t.Error("1st goroutine takes 1.8s. 2nd goroutine only wait for the lock 1 second.. should return fail to acquire lock error") + } + log.Println("2: unlocked") + }() - prov := coredb.NewTxProvider("newdb") + wg.Wait() +} + +func ExampleNewTxProvider() { + prov := coredb.NewTxProvider("testdb") err := prov.Tx(context.Background(), func(tx coredb.TxContext) error { _, err := tx.Exec("truncate table worker") panicOnErr(err) @@ -89,7 +133,7 @@ func ExampleNewTxProvider() { }) panicOnErr(err) - prov2 := coredb.NewTxProvider("newdb") + prov2 := coredb.NewTxProvider("testdb") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() err = prov2.Tx(ctx, func(tx coredb.TxContext) error { @@ -108,7 +152,6 @@ func ExampleNewTxProvider() { if !errors.Is(err, context.DeadlineExceeded) { panic(err) } - } func panicOnErr(err error) { @@ -116,31 +159,32 @@ func panicOnErr(err error) { panic(err) } } + func mustEqual(a, b interface{}) { if !reflect.DeepEqual(a, b) { panic(fmt.Sprintf("%v != %v", a, b)) } } -func open() (db *sql.DB, err error) { - dsn := "root:123456@tcp(127.0.0.1:3307)/newdb" - if !strings.Contains(dsn, "?parseTime=true") { - dsn += "?parseTime=true" - } +// func open() (db *sql.DB, err error) { +// dsn := "root:123456@tcp(127.0.0.1:3307)/testdb" +// if !strings.Contains(dsn, "?parseTime=true") { +// dsn += "?parseTime=true" +// } - maxIdle := 3.0 +// maxIdle := 3.0 - maxOpen := 50.0 +// maxOpen := 50.0 - maxLifetime := 30.0 +// maxLifetime := 30.0 - db, err = sql.Open("mysql", dsn) - if err != nil { - return nil, err - } +// db, err = sql.Open("mysql", dsn) +// if err != nil { +// return nil, err +// } - db.SetConnMaxIdleTime(time.Duration(maxIdle) * time.Second) - db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second) - db.SetMaxOpenConns(int(maxOpen)) - return -} +// db.SetConnMaxIdleTime(time.Duration(maxIdle) * time.Second) +// db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second) +// db.SetMaxOpenConns(int(maxOpen)) +// return +// } diff --git a/tests/user_test.go b/tests/user_test.go index a820b5a..f4ed9c8 100644 --- a/tests/user_test.go +++ b/tests/user_test.go @@ -23,9 +23,11 @@ const ( testDBName string = "testdb" ) -var tableNames = []string{"users", "blogs", "songs", "song_user_favourites", "profile", "account", +var tableNames = []string{ + "users", "blogs", "songs", "song_user_favourites", "profile", "account", "gifts", "gifts_with_default", "gifts_nn", "gifts_nn_with_default", "wallet", + "worker", } func init() { @@ -52,26 +54,23 @@ func init() { panic(err) } - realdb, err := open() + // realdb, err := open() - if err != nil { - panic(err) - } + // if err != nil { + // panic(err) + // } coredb.Setup(func(dbname string, mode coredb.DBMode) *sql.DB { - if dbname == "newdb" { - return realdb - } return db }) - //create tables + // create tables for _, tableName := range tableNames { query, _ := testdata.Fixtures.ReadFile(tableName + ".sql") db.Exec(string(query)) } - //add data + // add data _, err = db.Exec(` insert into users (name, email, created_at, updated_at, float_type, double_type, hobby, hobby_no_default, sports_no_default, sports) values ("John Doe", "john@doe.com", NOW(), NOW(), 1.55555, 1.8729, 'running','swimming', ('SWIM,TENNIS'), ("TENNIS")),