Skip to content

Commit

Permalink
Merge pull request #58 from olachat/yinloo/tx-conn
Browse files Browse the repository at this point in the history
implement TxWithLock
  • Loading branch information
yinloo-ola authored Mar 1, 2024
2 parents 3902eab + ef2e659 commit 05ee64c
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 47 deletions.
90 changes: 76 additions & 14 deletions coredb/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
// Make sure you call Commit or Rollback on the returned Tx.
// Refer to https://go.dev/doc/database/execute-transactions on how to use the returned Tx.
func BeginTx(ctx context.Context, dbname string, opts *sql.TxOptions) (tx *sql.Tx, err error) {
mydb := getDB(dbname, DBModeWrite)
return mydb.BeginTx(ctx, opts)
myDB := getDB(dbname, DBModeWrite)
return myDB.BeginTx(ctx, opts)
}

// DefaultTxOpts is package variable with default transaction level
Expand All @@ -21,6 +21,14 @@ var DefaultTxOpts = sql.TxOptions{
ReadOnly: false,
}

func newLockError(lock string, durationInSec int) error {
return fmt.Errorf("fail to acquire lock: %s, durationInSec: %d", lock, durationInSec)
}

func newReleaseLockError(lock string, durationInSec int) error {
return fmt.Errorf("fail to release lock: %s, durationInSec: %d", lock, durationInSec)
}

// TxContext interface for DAO operations with context.
type TxContext interface {
// Exec executes a query without returning any rows.
Expand Down Expand Up @@ -67,7 +75,9 @@ func (t *tx) Query(results any, query string, params ...any) error {
if err != nil {
return err
}
defer rows.Close()
defer func(rows *sql.Rows) {
err = rows.Close()
}(rows)
return RowsToStructSliceReflect(rows, results)
}

Expand All @@ -86,7 +96,7 @@ func (t *tx) FindOne(result any, tableName string, whereSQL string, params ...an
if err2 != nil {
// It's on purpose the hide the error
// But should re-consider later
if err2 == sql.ErrNoRows {
if errors.Is(err2, sql.ErrNoRows) {
return nil
}
return err2
Expand All @@ -112,27 +122,39 @@ func (t *tx) Rollback() error {
return t.Tx.Rollback()
}

// Connector for sql database.
type Connector interface {
// TxStarter for sql database.
type TxStarter interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

type ConnectionGetter interface {
Conn(ctx context.Context) (*sql.Conn, error)
}

type TxStarterWithConnection interface {
TxStarter
ConnectionGetter
}

// TxProvider ...
type TxProvider struct {
conn Connector
conn TxStarterWithConnection
}

// NewTxProvider ...
func NewTxProvider(dbname string) *TxProvider {
mydb := getDB(dbname, DBModeWrite)
myDB := getDB(dbname, DBModeWrite)
return &TxProvider{
conn: mydb,
conn: myDB,
}
}

// acquireWithOpts transaction from db
func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) (*tx, error) {
trx, err := t.conn.BeginTx(ctx, opts)
func (t *TxProvider) acquireWithOpts(ctx context.Context, conn TxStarter, opts *sql.TxOptions) (*tx, error) {
if conn == nil {
conn = t.conn
}
trx, err := conn.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
Expand All @@ -144,9 +166,9 @@ func (t *TxProvider) acquireWithOpts(ctx context.Context, opts *sql.TxOptions) (
}

// TxWithOpts ...
func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, opts *sql.TxOptions) (err error) {
func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, conn TxStarter, opts *sql.TxOptions) (err error) {
var trx *tx
trx, err = t.acquireWithOpts(ctx, opts)
trx, err = t.acquireWithOpts(ctx, conn, opts)
if err != nil {
return err
}
Expand Down Expand Up @@ -180,5 +202,45 @@ func (t *TxProvider) TxWithOpts(ctx context.Context, fn func(TxContext) error, o

// Tx runs fn in transaction.
func (t *TxProvider) Tx(ctx context.Context, fn func(TxContext) error) error {
return t.TxWithOpts(ctx, fn, &DefaultTxOpts)
return t.TxWithOpts(ctx, fn, nil, &DefaultTxOpts)
}

func (t *TxProvider) TxWithLock(ctx context.Context, lock string, durationInSec int, fn func(txContext TxContext) error) error {
dbConn, err := t.conn.Conn(ctx)
if err != nil {
return fmt.Errorf("fail to get db connection: %w", err)
}

{
var res int
err = dbConn.QueryRowContext(ctx, "select get_lock(?,?)", lock, durationInSec).Scan(&res)
if err != nil {
return fmt.Errorf("get_lock failed: %w", err)
}
if res != 1 {
return newLockError(lock, durationInSec)
}
}

defer func() {
var res int
errRelease := dbConn.QueryRowContext(ctx, "select release_lock(?)", lock).Scan(&res)
if errRelease != nil {
if err == nil {
err = fmt.Errorf("release_lock failed: %w", errRelease)
} else {
err = errors.Join(err, fmt.Errorf("release_lock failed: %w", errRelease))

Check failure on line 232 in coredb/tx.go

View workflow job for this annotation

GitHub Actions / build

undefined: errors.Join
}
return
}
if res != 1 {
if err == nil {
err = newReleaseLockError(lock, durationInSec)
} else {
err = errors.Join(err, newReleaseLockError(lock, durationInSec))

Check failure on line 240 in coredb/tx.go

View workflow job for this annotation

GitHub Actions / build

undefined: errors.Join
}
}
}()

return t.TxWithOpts(ctx, fn, dbConn, &DefaultTxOpts)
}
90 changes: 67 additions & 23 deletions tests/tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package tests

import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"reflect"
"strings"
"sync"
"testing"
"time"

_ "github.com/go-sql-driver/mysql"
Expand All @@ -15,9 +16,52 @@ import (
"github.com/olachat/gola/v2/golalib/testdata/worker"
)

func ExampleNewTxProvider() {
func TestTxWithLock(t *testing.T) {
prov := coredb.NewTxProvider("testdb")
ctx := context.Background()
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
wg := &sync.WaitGroup{}

wg.Add(1)
go func() {
defer wg.Done()
log.Println("1: start lock")
err1 := prov.TxWithLock(ctx, "lock", 2, func(tx coredb.TxContext) error {
log.Println("1: locked")
time.Sleep(1800 * time.Millisecond)
log.Println("1: start unlock")
return nil
})
if err1 != nil {
log.Printf("1: error: %v", err1)
}
log.Println("1: unlocked")
}()

time.Sleep(10 * time.Millisecond)
wg.Add(1)
go func() {
defer wg.Done()
log.Println("2: start lock")
err2 := prov.TxWithLock(ctx, "lock", 1, func(tx coredb.TxContext) error {
log.Println("2: locked")
time.Sleep(800 * time.Millisecond)
log.Println("2: start unlock")
return nil
})
if err2 != nil {
log.Printf("2: error: %v", err2)
} else {
t.Error("1st goroutine takes 1.8s. 2nd goroutine only wait for the lock 1 second.. should return fail to acquire lock error")
}
log.Println("2: unlocked")
}()

prov := coredb.NewTxProvider("newdb")
wg.Wait()
}

func ExampleNewTxProvider() {
prov := coredb.NewTxProvider("testdb")
err := prov.Tx(context.Background(), func(tx coredb.TxContext) error {
_, err := tx.Exec("truncate table worker")
panicOnErr(err)
Expand Down Expand Up @@ -89,7 +133,7 @@ func ExampleNewTxProvider() {
})
panicOnErr(err)

prov2 := coredb.NewTxProvider("newdb")
prov2 := coredb.NewTxProvider("testdb")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = prov2.Tx(ctx, func(tx coredb.TxContext) error {
Expand All @@ -108,39 +152,39 @@ func ExampleNewTxProvider() {
if !errors.Is(err, context.DeadlineExceeded) {
panic(err)
}

}

func panicOnErr(err error) {
if err != nil {
panic(err)
}
}

func mustEqual(a, b interface{}) {
if !reflect.DeepEqual(a, b) {
panic(fmt.Sprintf("%v != %v", a, b))
}
}

func open() (db *sql.DB, err error) {
dsn := "root:123456@tcp(127.0.0.1:3307)/newdb"
if !strings.Contains(dsn, "?parseTime=true") {
dsn += "?parseTime=true"
}
// func open() (db *sql.DB, err error) {
// dsn := "root:123456@tcp(127.0.0.1:3307)/testdb"
// if !strings.Contains(dsn, "?parseTime=true") {
// dsn += "?parseTime=true"
// }

maxIdle := 3.0
// maxIdle := 3.0

maxOpen := 50.0
// maxOpen := 50.0

maxLifetime := 30.0
// maxLifetime := 30.0

db, err = sql.Open("mysql", dsn)
if err != nil {
return nil, err
}
// db, err = sql.Open("mysql", dsn)
// if err != nil {
// return nil, err
// }

db.SetConnMaxIdleTime(time.Duration(maxIdle) * time.Second)
db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second)
db.SetMaxOpenConns(int(maxOpen))
return
}
// db.SetConnMaxIdleTime(time.Duration(maxIdle) * time.Second)
// db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second)
// db.SetMaxOpenConns(int(maxOpen))
// return
// }
19 changes: 9 additions & 10 deletions tests/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ const (
testDBName string = "testdb"
)

var tableNames = []string{"users", "blogs", "songs", "song_user_favourites", "profile", "account",
var tableNames = []string{
"users", "blogs", "songs", "song_user_favourites", "profile", "account",
"gifts", "gifts_with_default",
"gifts_nn", "gifts_nn_with_default", "wallet",
"worker",
}

func init() {
Expand All @@ -52,26 +54,23 @@ func init() {
panic(err)
}

realdb, err := open()
// realdb, err := open()

if err != nil {
panic(err)
}
// if err != nil {
// panic(err)
// }

coredb.Setup(func(dbname string, mode coredb.DBMode) *sql.DB {
if dbname == "newdb" {
return realdb
}
return db
})

//create tables
// create tables
for _, tableName := range tableNames {
query, _ := testdata.Fixtures.ReadFile(tableName + ".sql")
db.Exec(string(query))
}

//add data
// add data
_, err = db.Exec(`
insert into users (name, email, created_at, updated_at, float_type, double_type, hobby, hobby_no_default, sports_no_default, sports) values
("John Doe", "[email protected]", NOW(), NOW(), 1.55555, 1.8729, 'running','swimming', ('SWIM,TENNIS'), ("TENNIS")),
Expand Down

0 comments on commit 05ee64c

Please sign in to comment.