diff --git a/dbutil/database.go b/dbutil/database.go index a15383c..b0cd9a3 100644 --- a/dbutil/database.go +++ b/dbutil/database.go @@ -14,6 +14,8 @@ import ( "regexp" "strings" "time" + + "go.mau.fi/util/exsync" ) type Dialect int @@ -114,10 +116,12 @@ type Database struct { Dialect Dialect UpgradeTable UpgradeTable - txnCtxKey contextKey + txnCtxKey contextKey + txnDeadlockMap *exsync.Set[int64] IgnoreForeignTables bool IgnoreUnsupportedDatabase bool + DeadlockDetection bool } var positionalParamPattern = regexp.MustCompile(`\$(\d+)`) @@ -144,10 +148,12 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da Log: log, Dialect: db.Dialect, - txnCtxKey: db.txnCtxKey, + txnCtxKey: db.txnCtxKey, + txnDeadlockMap: db.txnDeadlockMap, IgnoreForeignTables: true, IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase, + DeadlockDetection: db.DeadlockDetection, } } @@ -164,7 +170,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) { IgnoreForeignTables: true, VersionTable: "version", - txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), + txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)), + txnDeadlockMap: exsync.NewSet[int64](), } wrappedDB.LoggingDB.UnderlyingExecable = db wrappedDB.LoggingDB.db = wrappedDB @@ -194,6 +201,8 @@ type PoolConfig struct { type Config struct { PoolConfig `yaml:",inline"` ReadOnlyPool PoolConfig `yaml:"ro_pool"` + + DeadlockDetection bool `yaml:"deadlock_detection"` } func (db *Database) Close() error { @@ -211,6 +220,8 @@ func (db *Database) Close() error { } func (db *Database) Configure(cfg Config) error { + db.DeadlockDetection = cfg.DeadlockDetection + if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil { return err } diff --git a/dbutil/transaction.go b/dbutil/transaction.go index 9211136..25ef1ef 100644 --- a/dbutil/transaction.go +++ b/dbutil/transaction.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/petermattis/goid" "github.com/rs/zerolog" "go.mau.fi/util/exerrors" @@ -51,10 +52,21 @@ func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sq return db.Execable(ctx).QueryRowContext(ctx, query, args...) } +var ErrTransactionDeadlock = errors.New("attempt to start new transaction in goroutine with transaction") +var ErrQueryDeadlock = errors.New("attempt to query without context in goroutine with transaction") +var ErrAcquireDeadlock = errors.New("attempt to acquire connection without context in goroutine with transaction") + func (db *Database) BeginTx(ctx context.Context, opts *TxnOptions) (*LoggingTxn, error) { if ctx == nil { panic("BeginTx() called with nil ctx") } + if db.DeadlockDetection { + goroutineID := goid.Get() + if !db.txnDeadlockMap.Add(goroutineID) { + panic(ErrTransactionDeadlock) + } + defer db.txnDeadlockMap.Remove(goroutineID) + } return db.LoggingDB.BeginTx(ctx, opts) } @@ -141,6 +153,9 @@ func (db *Database) Execable(ctx context.Context) Execable { if ok { return txn } + if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) { + panic(ErrQueryDeadlock) + } return &db.LoggingDB } @@ -152,6 +167,9 @@ func (db *Database) AcquireConn(ctx context.Context) (Conn, error) { if ok { return nil, fmt.Errorf("cannot acquire connection while in a transaction") } + if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) { + panic(ErrAcquireDeadlock) + } conn, err := db.RawDB.Conn(ctx) if err != nil { return nil, err diff --git a/go.mod b/go.mod index 764b376..5e1cb74 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/mattn/go-sqlite3 v1.14.22 + github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240707233637-46b078467d37 diff --git a/go.sum b/go.sum index 8db4904..abc4474 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 h1:DUDJI8T/9NcGbbL+AWk6vIYlmQ8ZBS8LZqVre6zbkPQ= +github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=