Skip to content

Commit

Permalink
dbutil: add option to detect database calls with incorrect context (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir authored Aug 11, 2024
1 parent c1b6f86 commit eba5f6e
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 3 deletions.
21 changes: 18 additions & 3 deletions dbutil/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"regexp"
"strings"
"time"

"go.mau.fi/util/exsync"
)

type Dialect int
Expand Down Expand Up @@ -114,12 +116,16 @@ type Database struct {
Dialect Dialect
UpgradeTable UpgradeTable

txnCtxKey contextKey
txnCtxKey contextKey
txnDeadlockMap *exsync.Set[int64]

IgnoreForeignTables bool
IgnoreUnsupportedDatabase bool
DeadlockDetection bool
}

var ForceDeadlockDetection bool

var positionalParamPattern = regexp.MustCompile(`\$(\d+)`)

func (db *Database) mutateQuery(query string) string {
Expand All @@ -144,10 +150,12 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
Log: log,
Dialect: db.Dialect,

txnCtxKey: db.txnCtxKey,
txnCtxKey: db.txnCtxKey,
txnDeadlockMap: db.txnDeadlockMap,

IgnoreForeignTables: true,
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
DeadlockDetection: db.DeadlockDetection,
}
}

Expand All @@ -164,7 +172,10 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
IgnoreForeignTables: true,
VersionTable: "version",

txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),
txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),
txnDeadlockMap: exsync.NewSet[int64](),

DeadlockDetection: ForceDeadlockDetection,
}
wrappedDB.LoggingDB.UnderlyingExecable = db
wrappedDB.LoggingDB.db = wrappedDB
Expand Down Expand Up @@ -194,6 +205,8 @@ type PoolConfig struct {
type Config struct {
PoolConfig `yaml:",inline"`
ReadOnlyPool PoolConfig `yaml:"ro_pool"`

DeadlockDetection bool `yaml:"deadlock_detection"`
}

func (db *Database) Close() error {
Expand All @@ -211,6 +224,8 @@ func (db *Database) Close() error {
}

func (db *Database) Configure(cfg Config) error {
db.DeadlockDetection = cfg.DeadlockDetection || ForceDeadlockDetection

if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil {
return err
}
Expand Down
140 changes: 140 additions & 0 deletions dbutil/deadlock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package dbutil_test

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"go.mau.fi/util/dbutil"
_ "go.mau.fi/util/dbutil/litestream"
)

func initTestDB(t *testing.T) *dbutil.Database {
db, err := dbutil.NewFromConfig("", dbutil.Config{
PoolConfig: dbutil.PoolConfig{
Type: "sqlite3-fk-wal",
URI: ":memory:?_txlock=immediate",
MaxOpenConns: 1,
MaxIdleConns: 1,
},
DeadlockDetection: true,
}, nil)
require.NoError(t, err)
ctx := context.Background()
_, err = db.Exec(ctx, `
CREATE TABLE meow (id INTEGER PRIMARY KEY, value TEXT);
INSERT INTO meow (id, value) VALUES (1, 'meow');
INSERT INTO meow (id, value) VALUES (2, 'meow 2');
INSERT INTO meow (value) VALUES ('meow 3');
`)
require.NoError(t, err)
return db
}

func getMeow(ctx context.Context, db dbutil.Execable, id int) (value string, err error) {
err = db.QueryRowContext(ctx, "SELECT value FROM meow WHERE id = ?", id).Scan(&value)
return
}

func TestDatabase_NoDeadlock(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
require.NoError(t, db.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := db.Exec(ctx, "INSERT INTO meow (value) VALUES ('meow 4');")
require.NoError(t, err)
return nil
}))
val, err := getMeow(ctx, db.Execable(ctx), 4)
require.NoError(t, err)
require.Equal(t, "meow 4", val)
}

func TestDatabase_NoDeadlock_Goroutine(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
require.NoError(t, db.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := db.Exec(ctx, "INSERT INTO meow (value) VALUES ('meow 4');")
require.NoError(t, err)
go func() {
_, err := db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 5');")
require.NoError(t, err)
}()
time.Sleep(50 * time.Millisecond)
return nil
}))
val, err := getMeow(ctx, db.Execable(ctx), 4)
require.NoError(t, err)
require.Equal(t, "meow 4", val)
val, err = getMeow(ctx, db.Execable(ctx), 5)
require.NoError(t, err)
require.Equal(t, "meow 5", val)
}

func TestDatabase_Deadlock(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() {
_, _ = db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');")
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Acquire(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrAcquireDeadlock.Error(), func() {
_, _ = db.AcquireConn(context.Background())
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Txn(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrTransactionDeadlock.Error(), func() {
_ = db.DoTxn(context.Background(), nil, func(ctx context.Context) error {
return nil
})
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Child(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
childDB := db.Child("", nil, nil)
_ = db.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() {
_, _ = childDB.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');")
})
return fmt.Errorf("meow")
})
}

func TestDatabase_Deadlock_Child2(t *testing.T) {
db := initTestDB(t)
ctx := context.Background()
childDB := db.Child("", nil, nil)
_ = childDB.DoTxn(ctx, nil, func(ctx context.Context) error {
assert.PanicsWithError(t, dbutil.ErrQueryDeadlock.Error(), func() {
_, _ = db.Exec(context.Background(), "INSERT INTO meow (value) VALUES ('meow 4');")
})
return fmt.Errorf("meow")
})
}
17 changes: 17 additions & 0 deletions dbutil/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"sync/atomic"
"time"

"github.com/petermattis/goid"
"github.com/rs/zerolog"

"go.mau.fi/util/exerrors"
Expand Down Expand Up @@ -51,6 +52,10 @@ func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sq
return db.Execable(ctx).QueryRowContext(ctx, query, args...)
}

var ErrTransactionDeadlock = errors.New("attempt to start new transaction in goroutine with transaction")
var ErrQueryDeadlock = errors.New("attempt to query without context in goroutine with transaction")
var ErrAcquireDeadlock = errors.New("attempt to acquire connection without context in goroutine with transaction")

func (db *Database) BeginTx(ctx context.Context, opts *TxnOptions) (*LoggingTxn, error) {
if ctx == nil {
panic("BeginTx() called with nil ctx")
Expand All @@ -65,6 +70,12 @@ func (db *Database) DoTxn(ctx context.Context, opts *TxnOptions, fn func(ctx con
if ctx.Value(db.txnCtxKey) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
return fn(ctx)
} else if db.DeadlockDetection {
goroutineID := goid.Get()
if !db.txnDeadlockMap.Add(goroutineID) {
panic(ErrTransactionDeadlock)
}
defer db.txnDeadlockMap.Remove(goroutineID)
}

log := zerolog.Ctx(ctx).With().Str("db_txn_id", random.String(12)).Logger()
Expand Down Expand Up @@ -141,6 +152,9 @@ func (db *Database) Execable(ctx context.Context) Execable {
if ok {
return txn
}
if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) {
panic(ErrQueryDeadlock)
}
return &db.LoggingDB
}

Expand All @@ -152,6 +166,9 @@ func (db *Database) AcquireConn(ctx context.Context) (Conn, error) {
if ok {
return nil, fmt.Errorf("cannot acquire connection while in a transaction")
}
if db.DeadlockDetection && db.txnDeadlockMap.Has(goid.Get()) {
panic(ErrAcquireDeadlock)
}
conn, err := db.RawDB.Conn(ctx)
if err != nil {
return nil, err
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.21
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/mattn/go-sqlite3 v1.14.22
github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240707233637-46b078467d37
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 h1:DUDJI8T/9NcGbbL+AWk6vIYlmQ8ZBS8LZqVre6zbkPQ=
github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down

0 comments on commit eba5f6e

Please sign in to comment.