diff --git a/database/db_test.go b/database/db_test.go index 36555f21b..37102df91 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -1,19 +1,24 @@ package database import ( - "encoding/json" "fmt" "os" "testing" + "time" "github.com/konveyor/tackle2-hub/model" + "gorm.io/gorm" + "k8s.io/utils/env" ) -var N = 800 +var N, _ = env.GetInt("TEST_CONCURRENT", 10) func TestConcurrent(t *testing.T) { - Settings.DB.Path = "/tmp/concurrent.db" - _ = os.Remove(Settings.DB.Path) + pid := os.Getpid() + Settings.DB.Path = fmt.Sprintf("/tmp/concurrent-%d.db", pid) + defer func() { + _ = os.Remove(Settings.DB.Path) + }() db, err := Open(true) if err != nil { panic(err) @@ -22,13 +27,39 @@ func TestConcurrent(t *testing.T) { for w := 0; w < N; w++ { go func(id int) { fmt.Printf("Started %d\n", id) - for n := 0; n < N; n++ { - v, _ := json.Marshal(fmt.Sprintf("Test-%d", n)) - m := &model.Setting{Key: fmt.Sprintf("key-%d-%d", id, n), Value: v} + for n := 0; n < N*10; n++ { + m := &model.Setting{Key: fmt.Sprintf("key-%d-%d", id, n), Value: n} + fmt.Printf("(%.4d) CREATE: %.4d\n", id, n) uErr := db.Create(m).Error if uErr != nil { panic(uErr) } + uErr = db.Save(m).Error + if uErr != nil { + panic(uErr) + } + for i := 0; i < 10; i++ { + fmt.Printf("(%.4d) READ: %.4d/%.4d\n", id, n, i) + uErr = db.First(m).Error + if uErr != nil { + panic(uErr) + } + } + for i := 0; i < 4; i++ { + uErr = db.Transaction(func(tx *gorm.DB) (err error) { + time.Sleep(time.Millisecond * 10) + for i := 0; i < 3; i++ { + err = tx.Save(m).Error + if err != nil { + break + } + } + return + }) + if uErr != nil { + panic(uErr) + } + } } dq <- id }(w) diff --git a/database/driver.go b/database/driver.go index 708228530..14bfbb1fa 100644 --- a/database/driver.go +++ b/database/driver.go @@ -44,9 +44,9 @@ func (d *Driver) Driver() driver.Driver { } type Conn struct { - mutex *sync.Mutex - wrapped driver.Conn - tx driver.Tx + mutex *sync.Mutex + wrapped driver.Conn + hasMutex bool } func (c *Conn) Ping(ctx context.Context) (err error) { @@ -70,9 +70,8 @@ func (c *Conn) IsValid() (b bool) { } func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) { - if c.tx == nil { - c.mutex.Lock() - defer c.mutex.Unlock() + if c.needsMutex(query) { + c.acquire() } if p, cast := c.wrapped.(driver.QueryerContext); cast { r, err = p.QueryContext(ctx, query, args) @@ -81,29 +80,17 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam } func (c *Conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) { + if c.needsMutex(query) { + c.acquire() + } if p, cast := c.wrapped.(driver.ConnPrepareContext); cast { s, err = p.PrepareContext(ctx, query) } - if err != nil { - return - } - stmtLocked := c.stmtLocked(query) - s = &Stmt{ - mutex: c.mutex, - locked: stmtLocked, - wrapped: s, - } - if stmtLocked { - c.mutex.Lock() - } return } func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { - if c.tx == nil { - c.mutex.Lock() - defer c.mutex.Unlock() - } + c.acquire() if p, cast := c.wrapped.(driver.ExecerContext); cast { r, err = p.ExecContext(ctx, query, args) } @@ -111,61 +98,40 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name } func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { + c.acquire() if p, cast := c.wrapped.(driver.ConnBeginTx); cast { tx, err = p.BeginTx(ctx, opts) } else { tx, err = c.wrapped.Begin() } - if err != nil { - return - } - tx = &Tx{ - mutex: c.mutex, - wrapped: tx, - } - c.tx = tx - c.mutex.Lock() return } func (c *Conn) Prepare(query string) (s driver.Stmt, err error) { - s, err = c.wrapped.Prepare(query) - if err != nil { - return - } - stmtLocked := c.stmtLocked(query) - s = &Stmt{ - mutex: c.mutex, - locked: stmtLocked, - wrapped: s, - } - if stmtLocked { - c.mutex.Lock() + if c.needsMutex(query) { + c.acquire() } + s, err = c.wrapped.Prepare(query) return } func (c *Conn) Close() (err error) { err = c.wrapped.Close() + c.release() return } func (c *Conn) Begin() (tx driver.Tx, err error) { + c.acquire() tx, err = c.wrapped.Begin() if err != nil { return } - tx = &Tx{ - mutex: c.mutex, - wrapped: tx, - } - c.tx = tx - c.mutex.Lock() return } -func (c *Conn) stmtLocked(query string) (matched bool) { - if c.tx != nil || query == "" { +func (c *Conn) needsMutex(query string) (matched bool) { + if query == "" { return } query = strings.ToUpper(query) @@ -178,63 +144,16 @@ func (c *Conn) stmtLocked(query string) (matched bool) { return } -type Tx struct { - mutex *sync.Mutex - wrapped driver.Tx -} - -func (tx *Tx) Commit() (err error) { - defer func() { - tx.mutex.Unlock() - }() - err = tx.wrapped.Commit() - return -} -func (tx *Tx) Rollback() (err error) { - defer func() { - tx.mutex.Unlock() - }() - err = tx.wrapped.Rollback() - return -} - -type Stmt struct { - mutex *sync.Mutex - wrapped driver.Stmt - locked bool -} - -func (s *Stmt) Close() (err error) { - if s.locked { - s.mutex.Unlock() - } - err = s.wrapped.Close() - return -} -func (s *Stmt) NumInput() (n int) { - n = s.wrapped.NumInput() - return -} -func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) { - r, err = s.wrapped.Exec(args) - return -} - -func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { - if p, cast := s.wrapped.(driver.StmtExecContext); cast { - r, err = p.ExecContext(ctx, args) +func (c *Conn) acquire() { + if !c.hasMutex { + c.mutex.Lock() + c.hasMutex = true } - return } -func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) { - r, err = s.wrapped.Query(args) - return -} - -func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) { - if p, cast := s.wrapped.(driver.StmtQueryContext); cast { - r, err = p.QueryContext(ctx, args) +func (c *Conn) release() { + if c.hasMutex { + c.mutex.Unlock() + c.hasMutex = false } - return } diff --git a/database/pkg.go b/database/pkg.go index 3e9adfdae..b513c1c6c 100644 --- a/database/pkg.go +++ b/database/pkg.go @@ -18,7 +18,7 @@ var log = logr.WithName("db") var Settings = &settings.Settings const ( - ConnectionString = "file:%s?_journal=WAL" + ConnectionString = "file:%s?_journal=WAL&_timeout=100" FKsOn = "&_foreign_keys=yes" FKsOff = "&_foreign_keys=no" )