Skip to content

Commit

Permalink
Fix database-locked error.
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff Ortel <[email protected]>
  • Loading branch information
jortel committed Aug 20, 2024
1 parent f4a4874 commit 4a2e5ac
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 107 deletions.
19 changes: 18 additions & 1 deletion database/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (
"testing"

"github.com/konveyor/tackle2-hub/model"
"k8s.io/utils/env"
)

var N = 800
var N, _ = env.GetInt("TEST_CONCURRENT", 10)

func TestConcurrent(t *testing.T) {
Settings.DB.Path = "/tmp/concurrent.db"
Expand All @@ -25,10 +26,26 @@ func TestConcurrent(t *testing.T) {
for n := 0; n < N; n++ {
v, _ := json.Marshal(fmt.Sprintf("Test-%d", n))
m := &model.Setting{Key: fmt.Sprintf("key-%d-%d", id, n), Value: v}
fmt.Printf("(%.4d) CREATE: %.4d\n", id, n)
uErr := db.Create(m).Error
if uErr != nil {
panic(uErr)
}
for i := 0; i < 10; i++ {
fmt.Printf("(%.4d) BEGIN: %.4d/%.4d\n", id, n, i)
tx := db.Begin()
fmt.Printf("(%.4d) FIRST: %.4d/%.4d\n", id, n, i)
uErr = tx.First(m).Error
if uErr != nil {
panic(uErr)
}
fmt.Printf("(%.4d) SAVE: %.4d/%.4d\n", id, n, i)
uErr = tx.Save(m).Error
if uErr != nil {
panic(uErr)
}
tx.Commit()
}
}
dq <- id
}(w)
Expand Down
131 changes: 25 additions & 106 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ func (d *Driver) Driver() driver.Driver {
}

type Conn struct {
mutex *sync.Mutex
wrapped driver.Conn
tx driver.Tx
mutex *sync.Mutex
wrapped driver.Conn
hasMutex bool
}

func (c *Conn) Ping(ctx context.Context) (err error) {
Expand All @@ -70,9 +70,8 @@ func (c *Conn) IsValid() (b bool) {
}

func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) {
if c.tx == nil {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.QueryerContext); cast {
r, err = p.QueryContext(ctx, query, args)
Expand All @@ -81,91 +80,58 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
}

func (c *Conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.ConnPrepareContext); cast {
s, err = p.PrepareContext(ctx, query)
}
if err != nil {
return
}
stmtLocked := c.stmtLocked(query)
s = &Stmt{
mutex: c.mutex,
locked: stmtLocked,
wrapped: s,
}
if stmtLocked {
c.mutex.Lock()
}
return
}

func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
if c.tx == nil {
c.mutex.Lock()
defer c.mutex.Unlock()
}
c.acquire()
if p, cast := c.wrapped.(driver.ExecerContext); cast {
r, err = p.ExecContext(ctx, query, args)
}
return
}

func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
c.acquire()
if p, cast := c.wrapped.(driver.ConnBeginTx); cast {
tx, err = p.BeginTx(ctx, opts)
} else {
tx, err = c.wrapped.Begin()
}
if err != nil {
return
}
tx = &Tx{
mutex: c.mutex,
wrapped: tx,
}
c.tx = tx
c.mutex.Lock()
return
}

func (c *Conn) Prepare(query string) (s driver.Stmt, err error) {
s, err = c.wrapped.Prepare(query)
if err != nil {
return
}
stmtLocked := c.stmtLocked(query)
s = &Stmt{
mutex: c.mutex,
locked: stmtLocked,
wrapped: s,
}
if stmtLocked {
c.mutex.Lock()
if c.needsMutex(query) {
c.acquire()
}
s, err = c.wrapped.Prepare(query)
return
}

func (c *Conn) Close() (err error) {
err = c.wrapped.Close()
c.release()
return
}

func (c *Conn) Begin() (tx driver.Tx, err error) {
c.acquire()
tx, err = c.wrapped.Begin()
if err != nil {
return
}
tx = &Tx{
mutex: c.mutex,
wrapped: tx,
}
c.tx = tx
c.mutex.Lock()
return
}

func (c *Conn) stmtLocked(query string) (matched bool) {
if c.tx != nil || query == "" {
func (c *Conn) needsMutex(query string) (matched bool) {
if query == "" {
return
}
query = strings.ToUpper(query)
Expand All @@ -178,63 +144,16 @@ func (c *Conn) stmtLocked(query string) (matched bool) {
return
}

type Tx struct {
mutex *sync.Mutex
wrapped driver.Tx
}

func (tx *Tx) Commit() (err error) {
defer func() {
tx.mutex.Unlock()
}()
err = tx.wrapped.Commit()
return
}
func (tx *Tx) Rollback() (err error) {
defer func() {
tx.mutex.Unlock()
}()
err = tx.wrapped.Rollback()
return
}

type Stmt struct {
mutex *sync.Mutex
wrapped driver.Stmt
locked bool
}

func (s *Stmt) Close() (err error) {
if s.locked {
s.mutex.Unlock()
}
err = s.wrapped.Close()
return
}
func (s *Stmt) NumInput() (n int) {
n = s.wrapped.NumInput()
return
}
func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) {
r, err = s.wrapped.Exec(args)
return
}

func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
if p, cast := s.wrapped.(driver.StmtExecContext); cast {
r, err = p.ExecContext(ctx, args)
func (c *Conn) acquire() {
if !c.hasMutex {
c.mutex.Lock()
c.hasMutex = true
}
return
}

func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) {
r, err = s.wrapped.Query(args)
return
}

func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
if p, cast := s.wrapped.(driver.StmtQueryContext); cast {
r, err = p.QueryContext(ctx, args)
func (c *Conn) release() {
if c.hasMutex {
c.mutex.Unlock()
c.hasMutex = false
}
return
}

0 comments on commit 4a2e5ac

Please sign in to comment.