diff --git a/dbutil/database.go b/dbutil/database.go index 7d63576..a9960ce 100644 --- a/dbutil/database.go +++ b/dbutil/database.go @@ -14,6 +14,8 @@ import ( "regexp" "strings" "time" + + "go.mau.fi/util/exsync" ) type Dialect int @@ -114,12 +116,16 @@ type Database struct { Dialect Dialect UpgradeTable UpgradeTable - txnCtxKey contextKey + txnCtxKey contextKey + txnDeadlockMap *exsync.Set[int64] IgnoreForeignTables bool IgnoreUnsupportedDatabase bool + DeadlockDetection bool } +var ForceDeadlockDetection bool + var positionalParamPattern = regexp.MustCompile(`\$(\d+)`) func (db *Database) mutateQuery(query string) string { @@ -144,10 +150,12 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da Log: log, Dialect: db.Dialect, - txnCtxKey: db.txnCtxKey, + txnCtxKey: db.txnCtxKey, + txnDeadlockMap: db.txnDeadlockMap, IgnoreForeignTables: true, IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase, + DeadlockDetection: db.DeadlockDetection, } } @@ -164,7 +172,10 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) { IgnoreForeignTables: true, VersionTable: "version", - txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), + txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), + txnDeadlockMap: exsync.NewSet[int64](), + + DeadlockDetection: ForceDeadlockDetection, } wrappedDB.LoggingDB.UnderlyingExecable = db wrappedDB.LoggingDB.db = wrappedDB @@ -194,6 +205,8 @@ type PoolConfig struct { type Config struct { PoolConfig `yaml:",inline"` ReadOnlyPool PoolConfig `yaml:"ro_pool"` + + DeadlockDetection bool `yaml:"deadlock_detection"` } func (db *Database) Close() error { @@ -211,6 +224,8 @@ func (db *Database) Close() error { } func (db *Database) Configure(cfg Config) error { + db.DeadlockDetection = cfg.DeadlockDetection || ForceDeadlockDetection + if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil { return err } diff --git a/dbutil/deadlock_test.go b/dbutil/deadlock_test.go new file mode 100644 index 0000000..5e12c94 --- /dev/null +++ b/dbutil/deadlock_test.go @@ -0,0 +1,140 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package dbutil_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" +) + +func initTestDB(t *testing.T) *dbutil.Database { + db, err := dbutil.NewFromConfig("", dbutil.Config{ + PoolConfig: dbutil.PoolConfig{ + Type: "sqlite3-fk-wal", + URI: ":memory:?_txlock=immediate", + MaxOpenConns: 1, + MaxIdleConns: 1, + }, + DeadlockDetection: true, + }, nil) + require.NoError(t, err) + ctx := context.Background() + _, err = db.Exec(ctx, ` + CREATE TABLE meow (id INTEGER PRIMARY KEY, value TEXT); + INSERT INTO meow (id, value) VALUES (1, 'meow'); + INSERT INTO meow (id, value) VALUES (2, 'meow 2'); + INSERT INTO meow (value) VALUES ('meow 3'); + `) + require.NoError(t, err) + return db +} + +func getMeow(ctx context.Context, db dbutil.Execable, id int) (value string, err error) { + err = db.QueryRowContext(ctx, "SELECT value FROM meow WHERE id = ?", id).Scan(&value) + return +} + +func TestDatabase_NoDeadlock(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + require.NoError(t, db.DoTxn(ctx, nil, func(ctx context.Context) error { + _, err := db.Exec(ctx, "INSERT INTO meow (value) VALUES ('meow 4');") + require.NoError(t, err) + return nil + })) + val, err := getMeow(ctx, db.Execable(ctx), 4) + require.NoError(t, err) + require.Equal(t, "meow 4", val) +} + +func TestDatabase_NoDeadlock_Goroutine(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + require.NoError(t, db.DoTxn(ctx, nil, func(ctx context.Context) error { + _, err := db.Exec(ctx, "INSERT INTO meow (value) VALUES ('meow 4');") + require.NoError(t, err) + go func() { + _, err := db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 5');") + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + return nil + })) + val, err := getMeow(ctx, db.Execable(ctx), 4) + require.NoError(t, err) + require.Equal(t, "meow 4", val) + val, err = getMeow(ctx, db.Execable(ctx), 5) + require.NoError(t, err) + require.Equal(t, "meow 5", val) +} + +func TestDatabase_Deadlock(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + _ = db.DoTxn(ctx, nil, func(ctx context.Context) error { + assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() { + _, _ = db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');") + }) + return fmt.Errorf("meow") + }) +} + +func TestDatabase_Deadlock_Acquire(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + _ = db.DoTxn(ctx, nil, func(ctx context.Context) error { + assert.PanicsWithError(t, dbutil.ErrAcquireDeadlock.Error(), func() { + _, _ = db.AcquireConn(context.Background()) + }) + return fmt.Errorf("meow") + }) +} + +func TestDatabase_Deadlock_Txn(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + _ = db.DoTxn(ctx, nil, func(ctx context.Context) error { + assert.PanicsWithError(t, dbutil.ErrTransactionDeadlock.Error(), func() { + _ = db.DoTxn(context.Background(), nil, func(ctx context.Context) error { + return nil + }) + }) + return fmt.Errorf("meow") + }) +} + +func TestDatabase_Deadlock_Child(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + childDB := db.Child("", nil, nil) + _ = db.DoTxn(ctx, nil, func(ctx context.Context) error { + assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() { + _, _ = childDB.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');") + }) + return fmt.Errorf("meow") + }) +} + +func TestDatabase_Deadlock_Child2(t *testing.T) { + db := initTestDB(t) + ctx := context.Background() + childDB := db.Child("", nil, nil) + _ = childDB.DoTxn(ctx, nil, func(ctx context.Context) error { + assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() { + _, _ = db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');") + }) + return fmt.Errorf("meow") + }) +} diff --git a/dbutil/transaction.go b/dbutil/transaction.go index 9211136..d010ea5 100644 --- a/dbutil/transaction.go +++ b/dbutil/transaction.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/petermattis/goid" "github.com/rs/zerolog" "go.mau.fi/util/exerrors" @@ -51,6 +52,10 @@ func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sq return db.Execable(ctx).QueryRowContext(ctx, query, args...) } +var ErrTransactionDeadlock = errors.New("attempt to start new transaction in goroutine with transaction") +var ErrQueryDeadlock = errors.New("attempt to query without context in goroutine with transaction") +var ErrAcquireDeadlock = errors.New("attempt to acquire connection without context in goroutine with transaction") + func (db *Database) BeginTx(ctx context.Context, opts *TxnOptions) (*LoggingTxn, error) { if ctx == nil { panic("BeginTx() called with nil ctx") @@ -65,6 +70,12 @@ func (db *Database) DoTxn(ctx context.Context, opts *TxnOptions, fn func(ctx con if ctx.Value(db.txnCtxKey) != nil { zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one") return fn(ctx) + } else if db.DeadlockDetection { + goroutineID := goid.Get() + if !db.txnDeadlockMap.Add(goroutineID) { + panic(ErrTransactionDeadlock) + } + defer db.txnDeadlockMap.Remove(goroutineID) } log := zerolog.Ctx(ctx).With().Str("db_txn_id", random.String(12)).Logger() @@ -141,6 +152,9 @@ func (db *Database) Execable(ctx context.Context) Execable { if ok { return txn } + if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) { + panic(ErrQueryDeadlock) + } return &db.LoggingDB } @@ -152,6 +166,9 @@ func (db *Database) AcquireConn(ctx context.Context) (Conn, error) { if ok { return nil, fmt.Errorf("cannot acquire connection while in a transaction") } + if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) { + panic(ErrAcquireDeadlock) + } conn, err := db.RawDB.Conn(ctx) if err != nil { return nil, err diff --git a/go.mod b/go.mod index 764b376..5e1cb74 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/mattn/go-sqlite3 v1.14.22 + github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240707233637-46b078467d37 diff --git a/go.sum b/go.sum index 8db4904..abc4474 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 h1:DUDJI8T/9NcGbbL+AWk6vIYlmQ8ZBS8LZqVre6zbkPQ= +github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=