Skip to content

Commit

Permalink
Merge pull request #7 from randlabs/improvements
Browse files Browse the repository at this point in the history
Improvements and new features
  • Loading branch information
mxmauro authored May 6, 2024
2 parents fa37e31 + 4438472 commit f47d004
Show file tree
Hide file tree
Showing 11 changed files with 950 additions and 146 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ app
demo

vendor/*
qodana.yml
85 changes: 85 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package postgres_test

import (
"crypto/rand"
"encoding/json"
"flag"
"testing"
)

// -----------------------------------------------------------------------------

var (
pgUrl string
pgHost string
pgPort uint
pgUsername string
pgPassword string
pgDatabaseName string
)

var (
testJSON TestJSON
testBLOB []byte
testJSONBytes []byte
)

// -----------------------------------------------------------------------------

func init() {
flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.")
flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')")
flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)")
flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')")
flag.StringVar(&pgPassword, "password", "", "Specifies the user password.")
flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.")

testJSON = TestJSON{
Id: 1,
Text: "demo",
}

testBLOB = make([]byte, 1024)
_, _ = rand.Read(testBLOB)

testJSONBytes, _ = json.Marshal(testJSON)
}

// -----------------------------------------------------------------------------

func checkSettings(t *testing.T) {
if len(pgHost) == 0 {
t.Fatalf("Server host not specified")
}
if pgPort > 65535 {
t.Fatalf("Server port not specified or invalid")
}
if len(pgUsername) == 0 {
t.Fatalf("User name to access database server not specified")
}
if len(pgPassword) == 0 {
t.Fatalf("User password to access database server not specified")
}
if len(pgDatabaseName) == 0 {
t.Fatalf("Database name not specified")
}
}

func addressOf[T any](x T) *T {
return &x
}

func jsonReEncode(src string) (string, error) {
var v interface{}

err := json.Unmarshal([]byte(src), &v)
if err == nil {
var reencoded []byte

reencoded, err = json.Marshal(v)
if err == nil {
return string(reencoded), nil
}
}
return "", err
}
95 changes: 95 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package postgres

import (
"context"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

// -----------------------------------------------------------------------------

// Conn encloses a single connection object.
type Conn struct {
db *Database
conn *pgxpool.Conn
}

// -----------------------------------------------------------------------------

// DB returns the underlying database driver.
func (c *Conn) DB() *Database {
return c.db
}

// Exec executes an SQL statement within the single connection.
func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (int64, error) {
affectedRows := int64(0)
ct, err := c.conn.Exec(ctx, sql, args...)
if err == nil {
affectedRows = ct.RowsAffected()
}
return affectedRows, c.db.processError(err)
}

// QueryRow executes a SQL query within the single connection.
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
return &rowGetter{
db: c.db,
row: c.conn.QueryRow(ctx, sql, args...),
}
}

// QueryRows executes a SQL query within the single connection.
func (c *Conn) QueryRows(ctx context.Context, sql string, args ...interface{}) Rows {
rows, err := c.conn.Query(ctx, sql, args...)
return &rowsGetter{
db: c.db,
ctx: ctx,
rows: rows,
err: err,
}
}

// Copy executes a SQL copy query within the single connection.
func (c *Conn) Copy(ctx context.Context, tableName string, columnNames []string, callback CopyCallback) (int64, error) {
n, err := c.conn.CopyFrom(
ctx,
pgx.Identifier{tableName},
columnNames,
&copyWithCallback{
ctx: ctx,
callback: callback,
},
)

// Done
return n, c.db.processError(err)
}

// WithinTx executes a callback function within the context of a single connection.
func (c *Conn) WithinTx(ctx context.Context, cb WithinTxCallback) error {
innerTx, err := c.conn.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.ReadCommitted, //pgx.Serializable,
AccessMode: pgx.ReadWrite,
DeferrableMode: pgx.NotDeferrable,
})
if err == nil {
err = cb(ctx, Tx{
db: c.db,
tx: innerTx,
})
if err == nil {
err = innerTx.Commit(ctx)
if err != nil {
err = newError(err, "unable to commit db transaction")
}
}
if err != nil {
_ = innerTx.Rollback(context.Background()) // Using context.Background() on purpose
}
} else {
err = newError(err, "unable to start transaction")
}
return c.db.processError(err)
}
4 changes: 4 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,7 @@ func newError(wrappedErr error, text string) *Error {
func encodeDSN(s string) string {
return strings.ReplaceAll(s, "'", "\\'")
}

func quoteIdentifier(s string) string {
return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\""
}
20 changes: 3 additions & 17 deletions internal.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package postgres

import (
"context"
"errors"

"github.com/jackc/pgx/v5"
Expand All @@ -13,31 +12,18 @@ var errNoRows = &NoRowsError{}

// -----------------------------------------------------------------------------

// Gets a connection from the pool and initiates a transaction.
func (db *Database) getTx(ctx context.Context) (pgx.Tx, error) {
tx, err := db.pool.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.ReadCommitted, //pgx.Serializable,
AccessMode: pgx.ReadWrite,
DeferrableMode: pgx.NotDeferrable,
})
if err != nil {
return nil, newError(err, "unable to start transaction")
}

//Done
return tx, nil
}

func (db *Database) processError(err error) error {
isNoRows := false
if errors.Is(err, pgx.ErrNoRows) {
err = errNoRows
isNoRows = true
}

// Only deal with fatal database errors. Cancellation, timeouts and empty result sets are not considered fatal.
db.err.mutex.Lock()
defer db.err.mutex.Unlock()

if err != nil && IsDatabaseError(err) && err != errNoRows {
if err != nil && (!isNoRows) && IsDatabaseError(err) {
if db.err.last == nil {
db.err.last = err
if db.err.handler != nil {
Expand Down
Loading

0 comments on commit f47d004

Please sign in to comment.