diff --git a/Makefile b/Makefile index b092ad5f..77de6b71 100644 --- a/Makefile +++ b/Makefile @@ -141,6 +141,9 @@ endif test: go test -count=1 -v $(shell go list ./... | grep -v "hub/test") +test-db: + go test -count=1 -timeout=6h -v ./database... + # Run Hub REST API tests. test-api: HUB_BASE_URL=$(HUB_BASE_URL) go test -count=1 -p=1 -v -failfast ./test/api/... diff --git a/database/db_test.go b/database/db_test.go index dde8c6fa..3a4590f0 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/konveyor/tackle2-hub/api" "github.com/konveyor/tackle2-hub/model" "gorm.io/gorm" "k8s.io/utils/env" @@ -13,6 +14,40 @@ import ( var N, _ = env.GetInt("TEST_CONCURRENT", 10) +func TestDriver(t *testing.T) { + pid := os.Getpid() + Settings.DB.Path = fmt.Sprintf("/tmp/driver-%d.db", pid) + defer func() { + _ = os.Remove(Settings.DB.Path) + }() + db, err := Open(true) + if err != nil { + panic(err) + } + key := "driver" + m := &model.Setting{Key: key, Value: "Test"} + // insert. + err = db.Create(m).Error + if err != nil { + panic(err) + } + // update + err = db.Save(m).Error + if err != nil { + panic(err) + } + // select + err = db.First(m, m.ID).Error + if err != nil { + panic(err) + } + // delete + err = db.Delete(m).Error + if err != nil { + panic(err) + } +} + func TestConcurrent(t *testing.T) { pid := os.Getpid() Settings.DB.Path = fmt.Sprintf("/tmp/concurrent-%d.db", pid) @@ -23,12 +58,35 @@ func TestConcurrent(t *testing.T) { if err != nil { panic(err) } + + type A struct { + model.Model + } + + type B struct { + N int + model.Model + A A + AID uint + } + err = db.Migrator().AutoMigrate(&A{}, &B{}) + if err != nil { + panic(err) + } + + a := A{} + err = db.Create(&a).Error + if err != nil { + panic(err) + } + dq := make(chan int, N) for w := 0; w < N; w++ { go func(id int) { fmt.Printf("Started %d\n", id) - for n := 0; n < N*10; n++ { - m := &model.Setting{Key: fmt.Sprintf("key-%d-%d", id, n), Value: n} + for n := 0; n < N*100; n++ { + m := &B{N: n, A: a} + m.CreateUser = "Test" fmt.Printf("(%.4d) CREATE: %.4d\n", id, n) uErr := db.Create(m).Error if uErr != nil { @@ -45,6 +103,20 @@ func TestConcurrent(t *testing.T) { panic(uErr) } } + for i := 0; i < 10; i++ { + fmt.Printf("(%.4d) LIST: %.4d/%.4d\n", id, n, i) + page := api.Page{} + cursor := api.Cursor{} + mx := B{} + dbx := db.Model(mx) + dbx = dbx.Joins("A") + dbx = dbx.Limit(10) + cursor.With(dbx, page) + for cursor.Next(&mx) { + time.Sleep(time.Millisecond + 10) + fmt.Printf("(%.4d) NEXT: %.4d/%.4d ID=%d\n", id, n, i, mx.ID) + } + } for i := 0; i < 4; i++ { uErr = db.Transaction(func(tx *gorm.DB) (err error) { time.Sleep(time.Millisecond * 10) diff --git a/database/driver.go b/database/driver.go index 14bfbb1f..0c50360b 100644 --- a/database/driver.go +++ b/database/driver.go @@ -9,12 +9,16 @@ import ( "github.com/mattn/go-sqlite3" ) +// Driver is a wrapper around the SQLite driver. +// The purpose is to prevent database locked errors using +// a mutex around write operations. type Driver struct { mutex sync.Mutex wrapped driver.Driver dsn string } +// Open a connection. func (d *Driver) Open(dsn string) (conn driver.Conn, err error) { d.wrapped = &sqlite3.SQLiteDriver{} conn, err = d.wrapped.Open(dsn) @@ -28,27 +32,33 @@ func (d *Driver) Open(dsn string) (conn driver.Conn, err error) { return } +// OpenConnector opens a connection. func (d *Driver) OpenConnector(dsn string) (dc driver.Connector, err error) { d.dsn = dsn dc = d return } +// Connect opens a connection. func (d *Driver) Connect(context.Context) (conn driver.Conn, err error) { conn, err = d.Open(d.dsn) return } +// Driver returns the underlying driver. func (d *Driver) Driver() driver.Driver { return d } +// Conn is a DB connection. type Conn struct { mutex *sync.Mutex wrapped driver.Conn hasMutex bool + hasTx bool } +// Ping the DB. func (c *Conn) Ping(ctx context.Context) (err error) { if p, cast := c.wrapped.(driver.Pinger); cast { err = p.Ping(ctx) @@ -56,22 +66,35 @@ func (c *Conn) Ping(ctx context.Context) (err error) { return } +// ResetSession reset the connection. +// - Reset the Tx. +// - Release the mutex. func (c *Conn) ResetSession(ctx context.Context) (err error) { + defer func() { + c.hasTx = false + c.release() + }() if p, cast := c.wrapped.(driver.SessionResetter); cast { err = p.ResetSession(ctx) } return } + +// IsValid returns true when the connection is valid. +// When true, the connection may be reused by the sql package. func (c *Conn) IsValid() (b bool) { + b = true if p, cast := c.wrapped.(driver.Validator); cast { b = p.IsValid() } return } +// QueryContext execute a query with context. func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) { if c.needsMutex(query) { c.acquire() + defer c.release() } if p, cast := c.wrapped.(driver.QueryerContext); cast { r, err = p.QueryContext(ctx, query, args) @@ -79,24 +102,34 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam return } -func (c *Conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) { +// ExecContext executes an SQL/DDL statement with context. +func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { if c.needsMutex(query) { c.acquire() + defer c.release() } - if p, cast := c.wrapped.(driver.ConnPrepareContext); cast { - s, err = p.PrepareContext(ctx, query) + if p, cast := c.wrapped.(driver.ExecerContext); cast { + r, err = p.ExecContext(ctx, query, args) } return } -func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { +// Begin a transaction. +func (c *Conn) Begin() (tx driver.Tx, err error) { c.acquire() - if p, cast := c.wrapped.(driver.ExecerContext); cast { - r, err = p.ExecContext(ctx, query, args) + tx, err = c.wrapped.Begin() + if err != nil { + return + } + tx = &Tx{ + conn: c, + wrapped: tx, } + c.hasTx = true return } +// BeginTx begins a transaction. func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { c.acquire() if p, cast := c.wrapped.(driver.ConnBeginTx); cast { @@ -104,32 +137,49 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx } else { tx, err = c.wrapped.Begin() } + tx = &Tx{ + conn: c, + wrapped: tx, + } + c.hasTx = true return } -func (c *Conn) Prepare(query string) (s driver.Stmt, err error) { - if c.needsMutex(query) { - c.acquire() +// Prepare a statement. +func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) { + stmt, err = c.wrapped.Prepare(query) + stmt = &Stmt{ + conn: c, + wrapped: stmt, + query: query, } - s, err = c.wrapped.Prepare(query) return } -func (c *Conn) Close() (err error) { - err = c.wrapped.Close() - c.release() +// PrepareContext prepares a statement with context. +func (c *Conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { + if p, cast := c.wrapped.(driver.ConnPrepareContext); cast { + stmt, err = p.PrepareContext(ctx, query) + } else { + stmt, err = c.Prepare(query) + } + stmt = &Stmt{ + conn: c, + wrapped: stmt, + query: query, + } return } -func (c *Conn) Begin() (tx driver.Tx, err error) { - c.acquire() - tx, err = c.wrapped.Begin() - if err != nil { - return - } +// Close the connection. +func (c *Conn) Close() (err error) { + err = c.wrapped.Close() + c.hasMutex = false + c.release() return } +// needsMutex returns true when the query should is a write operation. func (c *Conn) needsMutex(query string) (matched bool) { if query == "" { return @@ -144,6 +194,9 @@ func (c *Conn) needsMutex(query string) (matched bool) { return } +// acquire the mutex. +// Since Locks are not reentrant, the mutex is acquired +// only if this connection has not already acquired it. func (c *Conn) acquire() { if !c.hasMutex { c.mutex.Lock() @@ -151,9 +204,127 @@ func (c *Conn) acquire() { } } +// release the mutex. +// Released only when: +// - This connection has acquired it +// - Not in a transaction. func (c *Conn) release() { - if c.hasMutex { + if c.hasMutex && !c.hasTx { c.mutex.Unlock() c.hasMutex = false } } + +// endTx report transaction has ended. +func (c *Conn) endTx() { + c.hasTx = false +} + +// Stmt is a SQL/DDL statement. +type Stmt struct { + wrapped driver.Stmt + conn *Conn + query string +} + +// Close the statement. +func (s *Stmt) Close() (err error) { + err = s.wrapped.Close() + return +} + +// NumInput returns the number of (query) input parameters. +func (s *Stmt) NumInput() (n int) { + n = s.wrapped.NumInput() + return +} + +// Exec executes the statement. +func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } + r, err = s.wrapped.Exec(args) + return +} + +// ExecContext executes the statement with context. +func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } + if p, cast := s.wrapped.(driver.StmtExecContext); cast { + r, err = p.ExecContext(ctx, args) + } else { + r, err = s.Exec(s.values(args)) + } + return +} + +// Query executes a query. +func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } + r, err = s.wrapped.Query(args) + return +} + +// QueryContext executes a query. +func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) { + if s.needsMutex() { + s.conn.acquire() + defer s.conn.release() + } + if p, cast := s.wrapped.(driver.StmtQueryContext); cast { + r, err = p.QueryContext(ctx, args) + } else { + r, err = s.Query(s.values(args)) + } + return +} + +// values converts named-values to values. +func (s *Stmt) values(named []driver.NamedValue) (out []driver.Value) { + for i := range named { + out = append(out, named[i].Value) + } + return +} + +// needsMutex returns true when the query should is a write operation. +func (s *Stmt) needsMutex() (matched bool) { + matched = s.conn.needsMutex(s.query) + return +} + +// Tx is a transaction. +type Tx struct { + wrapped driver.Tx + conn *Conn +} + +// Commit the transaction. +// Releases the mutex. +func (t *Tx) Commit() (err error) { + defer func() { + t.conn.endTx() + t.conn.release() + }() + err = t.wrapped.Commit() + return +} + +// Rollback the transaction. +// Releases the mutex. +func (t *Tx) Rollback() (err error) { + defer func() { + t.conn.endTx() + t.conn.release() + }() + err = t.wrapped.Rollback() + return +} diff --git a/database/pkg.go b/database/pkg.go index 6e9c5eb2..83cfc4f3 100644 --- a/database/pkg.go +++ b/database/pkg.go @@ -18,7 +18,7 @@ var log = logr.WithName("db") var Settings = &settings.Settings const ( - ConnectionString = "file:%s?_journal=WAL&_timeout=100" + ConnectionString = "file:%s?_journal=WAL" FKsOn = "&_foreign_keys=yes" FKsOff = "&_foreign_keys=no" )