Skip to content

Commit

Permalink
Merge pull request #64 from olachat/yinloo/tx-engine-2
Browse files Browse the repository at this point in the history
refactor tx_engine
  • Loading branch information
yinloo-ola authored Nov 26, 2024
2 parents 927feed + aaf97db commit d77b375
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 14 deletions.
7 changes: 7 additions & 0 deletions coredb/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ func BeginTx(ctx context.Context, dbname string, opts *sql.TxOptions) (tx *sql.T
return myDB.BeginTx(ctx, opts)
}

// Conn returns a single connection by either opening a new connection or returning an existing connection from the connection pool. Conn will block until either a connection is returned or ctx is canceled. Queries run on the same Conn will be run in the same database session.
//
// Every Conn must be returned to the database pool after use by calling [Conn.Close].
func Conn(ctx context.Context, dbname string, mode DBMode) (*sql.Conn, error) {
return getDB(dbname, DBModeWrite).Conn(ctx)
}

// DefaultTxOpts is package variable with default transaction level
var DefaultTxOpts = sql.TxOptions{
Isolation: sql.LevelDefault,
Expand Down
103 changes: 89 additions & 14 deletions coredb/txengine/tx_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"time"

"github.com/olachat/gola/v2/coredb"
)

type TypedTx[T any] sql.Tx
type Tx sql.Tx
type (
TypedTx[T any] sql.Tx
Tx sql.Tx
)

func WithTypedTx[T any](tx *sql.Tx) *TypedTx[T] {
return (*TypedTx[T])(tx)
Expand All @@ -20,32 +24,103 @@ func WithTx(tx *sql.Tx) *Tx {
return (*Tx)(tx)
}

func StartTx(ctx context.Context, tx *sql.Tx, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
func RunTransaction(ctx context.Context, dbName string, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
tx, err := coredb.BeginTx(ctx, dbName, &coredb.DefaultTxOpts)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
err = runTransaction(ctx, tx, nil, fn)
return
}

func runTransaction(ctx context.Context, tx *sql.Tx, conn *sql.Conn, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
if tx == nil && conn == nil {
return errors.New("wrong usage. tx and conn cannot both be nil")
}
if tx == nil {
tx, err = conn.BeginTx(ctx, &coredb.DefaultTxOpts)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
}
defer func() {
//nolint:gocritic
if r := recover(); r != nil {
_ = tx.Rollback()
errRollback := tx.Rollback()
var ok bool
err, ok = r.(error)
errPanic, ok := r.(error)
if !ok {
err = fmt.Errorf("%v", r)
errPanic = fmt.Errorf("%v", r)
}
err = errors.Join(err, errPanic, errRollback)
} else if err != nil {
errRollback := tx.Rollback()
if errors.Is(errRollback, sql.ErrTxDone) && ctx.Err() != nil {
errRollback = nil
}
if errRollback != nil {
err = fmt.Errorf("%v encountered. but rollback failed: %w", err, errRollback)
}
err = errors.Join(err, errRollback)
} else {
err = tx.Commit()
errCommit := tx.Commit()
err = errors.Join(err, errCommit)
}
}()

err = fn(ctx, tx)

return err
return
}

const lockTimeoutBuffer = 5 * time.Millisecond

func RunTransactionWithLock(ctx context.Context, dbName string, lock string, durationInSec int, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
connCtx, cancel := context.WithTimeout(ctx, time.Duration(durationInSec)*time.Second+lockTimeoutBuffer)
defer cancel()

conn, err := coredb.Conn(connCtx, dbName, coredb.DBModeWrite)
if err != nil {
return fmt.Errorf("fail to get db connection: %w", err)
}

defer func() {
if conn != nil {
errCloseConn := conn.Close()
if errCloseConn != nil {
log.Printf("fail to close db connection: %#v", errCloseConn)
err = errors.Join(err, errCloseConn)
}
}
}()

{
var res int
err = conn.QueryRowContext(ctx, "select get_lock(?,?)", lock, durationInSec).Scan(&res)
if err != nil {
return fmt.Errorf("get_lock failed: %w", err)
}
if res != 1 {
return newLockError(lock, durationInSec)
}
}

defer func() {
var res int
errRelease := conn.QueryRowContext(ctx, "select release_lock(?)", lock).Scan(&res)
if errRelease != nil {
err = errors.Join(err, fmt.Errorf("release_lock failed: %w", errRelease))
return
}
if res != 1 {
err = errors.Join(err, newReleaseLockError(lock, durationInSec))
}
}()

err = runTransaction(ctx, nil, conn, fn)
return
}

func newLockError(lock string, durationInSec int) error {
return fmt.Errorf("fail to acquire lock: %s, durationInSec: %d", lock, durationInSec)
}

func newReleaseLockError(lock string, durationInSec int) error {
return fmt.Errorf("fail to release lock: %s, durationInSec: %d", lock, durationInSec)
}

// FindOne returns a row from given table type with where query.
Expand Down

0 comments on commit d77b375

Please sign in to comment.