Skip to content

Commit

Permalink
pgx: don't use database/sql interface
Browse files Browse the repository at this point in the history
it'd be nice to have contexts for many of these methods, but that'd be a much wider change
  • Loading branch information
serprex committed Nov 29, 2024
1 parent c378583 commit 62b3399
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 174 deletions.
106 changes: 41 additions & 65 deletions database/pgx/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package pgx

import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
Expand All @@ -22,8 +21,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/lib/pq"
"github.com/jackc/pgx/v4"
)

const (
Expand Down Expand Up @@ -69,27 +67,26 @@ type Config struct {

type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
conn *pgx.Conn
isLocked atomic.Bool

// Open and WithInstance need to guarantee that config is never nil
config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
func WithInstance(instance *pgx.Conn, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := instance.Ping(context.Background()); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := instance.QueryRow(context.Background(), query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -103,7 +100,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
if err := instance.QueryRow(context.Background(), query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand Down Expand Up @@ -139,15 +136,8 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
}
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Postgres{
conn: conn,
db: instance,
conn: instance,
config: config,
}

Expand All @@ -173,7 +163,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
purl.Scheme = "postgres"

db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
db, err := pgx.Connect(context.Background(), migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -240,10 +230,9 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}

func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
connErr := p.conn.Close(context.Background())
if connErr != nil {
return fmt.Errorf("conn: %w", connErr)
}
return nil
}
Expand Down Expand Up @@ -283,19 +272,19 @@ func (p *Postgres) applyAdvisoryLock() error {

// This will wait indefinitely until the lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
return nil
}

func (p *Postgres) applyTableLock() error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
defer func() {
errRollback := tx.Rollback()
errRollback := tx.Rollback(context.Background())
if errRollback != nil {
err = multierror.Append(err, errRollback)
}
Expand All @@ -306,30 +295,25 @@ func (p *Postgres) applyTableLock() error {
return err
}

query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
rows, err := tx.Query(query, aid)
query := "SELECT * FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
rows, err := tx.Query(context.Background(), query, aid)
if err != nil {
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
}

defer func() {
if errClose := rows.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
defer rows.Close()

// If row exists at all, lock is present
locked := rows.Next()
if locked {
return database.ErrLocked
}

query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(query, aid); err != nil {
query = "INSERT INTO " + quoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(context.Background(), query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
}

return tx.Commit()
return tx.Commit(context.Background())
}

func (p *Postgres) releaseAdvisoryLock() error {
Expand All @@ -339,7 +323,7 @@ func (p *Postgres) releaseAdvisoryLock() error {
}

query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -352,8 +336,8 @@ func (p *Postgres) releaseTableLock() error {
return err
}

query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
if _, err := p.db.Exec(query, aid); err != nil {
query := "DELETE FROM " + quoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
if _, err := p.conn.Exec(context.Background(), query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
}

Expand Down Expand Up @@ -391,7 +375,7 @@ func (p *Postgres) runStatement(statement []byte) error {
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if _, err := p.conn.Exec(ctx, query); err != nil {

if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
Expand Down Expand Up @@ -448,14 +432,14 @@ func runesLastIndex(input []rune, target rune) int {
}

func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
tx, err := p.conn.BeginTx(context.Background(), pgx.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
if _, err := tx.Exec(context.Background(), query); err != nil {
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -466,15 +450,15 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
if _, err := tx.Exec(context.Background(), query, version, dirty); err != nil {
if errRollback := tx.Rollback(context.Background()); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}

if err := tx.Commit(); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}

Expand All @@ -483,9 +467,9 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {

func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
err = p.conn.QueryRow(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
case err == pgx.ErrNoRows:
return database.NilVersion, false, nil

case err != nil:
Expand All @@ -504,15 +488,11 @@ func (p *Postgres) Version() (version int, dirty bool, err error) {
func (p *Postgres) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
tables, err := p.conn.Query(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
defer tables.Close()

// delete one table after another
tableNames := make([]string, 0)
Expand All @@ -539,7 +519,7 @@ func (p *Postgres) Drop() (err error) {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if _, err := p.conn.Exec(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
Expand Down Expand Up @@ -571,7 +551,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
row := p.conn.QueryRow(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)

var count int
err = row.Scan(&count)
Expand All @@ -584,7 +564,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
}

query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
if _, err = p.conn.Exec(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -598,15 +578,15 @@ func (p *Postgres) ensureLockTable() error {

var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
if err := p.conn.QueryRow(context.Background(), query, p.config.LockTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}

query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
if _, err := p.db.Exec(query); err != nil {
query = `CREATE TABLE ` + quoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
if _, err := p.conn.Exec(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -615,9 +595,5 @@ func (p *Postgres) ensureLockTable() error {

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
return pgx.Identifier([]string{name}).Sanitize()
}
Loading

0 comments on commit 62b3399

Please sign in to comment.