Skip to content

Commit

Permalink
🐛 Fixes database-locked. (konveyor#747)
Browse files Browse the repository at this point in the history
Seems the underlying sqlite driver keeps the lock until the connection
is closed.
The `Conn` acquires the mutex and holds it until the connection is
closed. The `Tx` and `Stmt` are no longer necessary.

---------

Signed-off-by: Jeff Ortel <[email protected]>
  • Loading branch information
jortel authored Aug 21, 2024
1 parent f4a4874 commit cd7d9ed
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 114 deletions.
45 changes: 38 additions & 7 deletions database/db_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
package database

import (
"encoding/json"
"fmt"
"os"
"testing"
"time"

"github.com/konveyor/tackle2-hub/model"
"gorm.io/gorm"
"k8s.io/utils/env"
)

var N = 800
var N, _ = env.GetInt("TEST_CONCURRENT", 10)

func TestConcurrent(t *testing.T) {
Settings.DB.Path = "/tmp/concurrent.db"
_ = os.Remove(Settings.DB.Path)
pid := os.Getpid()
Settings.DB.Path = fmt.Sprintf("/tmp/concurrent-%d.db", pid)
defer func() {
_ = os.Remove(Settings.DB.Path)
}()
db, err := Open(true)
if err != nil {
panic(err)
Expand All @@ -22,13 +27,39 @@ func TestConcurrent(t *testing.T) {
for w := 0; w < N; w++ {
go func(id int) {
fmt.Printf("Started %d\n", id)
for n := 0; n < N; n++ {
v, _ := json.Marshal(fmt.Sprintf("Test-%d", n))
m := &model.Setting{Key: fmt.Sprintf("key-%d-%d", id, n), Value: v}
for n := 0; n < N*10; n++ {
m := &model.Setting{Key: fmt.Sprintf("key-%d-%d", id, n), Value: n}
fmt.Printf("(%.4d) CREATE: %.4d\n", id, n)
uErr := db.Create(m).Error
if uErr != nil {
panic(uErr)
}
uErr = db.Save(m).Error
if uErr != nil {
panic(uErr)
}
for i := 0; i < 10; i++ {
fmt.Printf("(%.4d) READ: %.4d/%.4d\n", id, n, i)
uErr = db.First(m).Error
if uErr != nil {
panic(uErr)
}
}
for i := 0; i < 4; i++ {
uErr = db.Transaction(func(tx *gorm.DB) (err error) {
time.Sleep(time.Millisecond * 10)
for i := 0; i < 3; i++ {
err = tx.Save(m).Error
if err != nil {
break
}
}
return
})
if uErr != nil {
panic(uErr)
}
}
}
dq <- id
}(w)
Expand Down
131 changes: 25 additions & 106 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ func (d *Driver) Driver() driver.Driver {
}

type Conn struct {
mutex *sync.Mutex
wrapped driver.Conn
tx driver.Tx
mutex *sync.Mutex
wrapped driver.Conn
hasMutex bool
}

func (c *Conn) Ping(ctx context.Context) (err error) {
Expand All @@ -70,9 +70,8 @@ func (c *Conn) IsValid() (b bool) {
}

func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) {
if c.tx == nil {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.QueryerContext); cast {
r, err = p.QueryContext(ctx, query, args)
Expand All @@ -81,91 +80,58 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
}

func (c *Conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.ConnPrepareContext); cast {
s, err = p.PrepareContext(ctx, query)
}
if err != nil {
return
}
stmtLocked := c.stmtLocked(query)
s = &Stmt{
mutex: c.mutex,
locked: stmtLocked,
wrapped: s,
}
if stmtLocked {
c.mutex.Lock()
}
return
}

func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
if c.tx == nil {
c.mutex.Lock()
defer c.mutex.Unlock()
}
c.acquire()
if p, cast := c.wrapped.(driver.ExecerContext); cast {
r, err = p.ExecContext(ctx, query, args)
}
return
}

func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
c.acquire()
if p, cast := c.wrapped.(driver.ConnBeginTx); cast {
tx, err = p.BeginTx(ctx, opts)
} else {
tx, err = c.wrapped.Begin()
}
if err != nil {
return
}
tx = &Tx{
mutex: c.mutex,
wrapped: tx,
}
c.tx = tx
c.mutex.Lock()
return
}

func (c *Conn) Prepare(query string) (s driver.Stmt, err error) {
s, err = c.wrapped.Prepare(query)
if err != nil {
return
}
stmtLocked := c.stmtLocked(query)
s = &Stmt{
mutex: c.mutex,
locked: stmtLocked,
wrapped: s,
}
if stmtLocked {
c.mutex.Lock()
if c.needsMutex(query) {
c.acquire()
}
s, err = c.wrapped.Prepare(query)
return
}

func (c *Conn) Close() (err error) {
err = c.wrapped.Close()
c.release()
return
}

func (c *Conn) Begin() (tx driver.Tx, err error) {
c.acquire()
tx, err = c.wrapped.Begin()
if err != nil {
return
}
tx = &Tx{
mutex: c.mutex,
wrapped: tx,
}
c.tx = tx
c.mutex.Lock()
return
}

func (c *Conn) stmtLocked(query string) (matched bool) {
if c.tx != nil || query == "" {
func (c *Conn) needsMutex(query string) (matched bool) {
if query == "" {
return
}
query = strings.ToUpper(query)
Expand All @@ -178,63 +144,16 @@ func (c *Conn) stmtLocked(query string) (matched bool) {
return
}

type Tx struct {
mutex *sync.Mutex
wrapped driver.Tx
}

func (tx *Tx) Commit() (err error) {
defer func() {
tx.mutex.Unlock()
}()
err = tx.wrapped.Commit()
return
}
func (tx *Tx) Rollback() (err error) {
defer func() {
tx.mutex.Unlock()
}()
err = tx.wrapped.Rollback()
return
}

type Stmt struct {
mutex *sync.Mutex
wrapped driver.Stmt
locked bool
}

func (s *Stmt) Close() (err error) {
if s.locked {
s.mutex.Unlock()
}
err = s.wrapped.Close()
return
}
func (s *Stmt) NumInput() (n int) {
n = s.wrapped.NumInput()
return
}
func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) {
r, err = s.wrapped.Exec(args)
return
}

func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
if p, cast := s.wrapped.(driver.StmtExecContext); cast {
r, err = p.ExecContext(ctx, args)
func (c *Conn) acquire() {
if !c.hasMutex {
c.mutex.Lock()
c.hasMutex = true
}
return
}

func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) {
r, err = s.wrapped.Query(args)
return
}

func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
if p, cast := s.wrapped.(driver.StmtQueryContext); cast {
r, err = p.QueryContext(ctx, args)
func (c *Conn) release() {
if c.hasMutex {
c.mutex.Unlock()
c.hasMutex = false
}
return
}
2 changes: 1 addition & 1 deletion database/pkg.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var log = logr.WithName("db")
var Settings = &settings.Settings

const (
ConnectionString = "file:%s?_journal=WAL"
ConnectionString = "file:%s?_journal=WAL&_timeout=100"
FKsOn = "&_foreign_keys=yes"
FKsOff = "&_foreign_keys=no"
)
Expand Down

0 comments on commit cd7d9ed

Please sign in to comment.