From 4438472747ead27e812c4641703af7f89e8dd640 Mon Sep 17 00:00:00 2001 From: Mauro Leggieri Date: Sun, 5 May 2024 22:19:45 -0300 Subject: [PATCH] Improvements and new features --- .gitignore | 1 + common_test.go | 85 +++++++++ connection.go | 95 ++++++++++ helpers.go | 4 + internal.go | 20 +-- migration.go | 429 ++++++++++++++++++++++++++++++++++++++++++++++ migration_test.go | 155 +++++++++++++++++ postgres.go | 152 +++++++++++----- postgres_test.go | 95 ++-------- qodana.yaml | 29 ++++ transaction.go | 31 +++- 11 files changed, 950 insertions(+), 146 deletions(-) create mode 100644 common_test.go create mode 100644 connection.go create mode 100644 migration.go create mode 100644 migration_test.go create mode 100644 qodana.yaml diff --git a/.gitignore b/.gitignore index af72e14..82a074d 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ app demo vendor/* +qodana.yml \ No newline at end of file diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..b653de1 --- /dev/null +++ b/common_test.go @@ -0,0 +1,85 @@ +package postgres_test + +import ( + "crypto/rand" + "encoding/json" + "flag" + "testing" +) + +// ----------------------------------------------------------------------------- + +var ( + pgUrl string + pgHost string + pgPort uint + pgUsername string + pgPassword string + pgDatabaseName string +) + +var ( + testJSON TestJSON + testBLOB []byte + testJSONBytes []byte +) + +// ----------------------------------------------------------------------------- + +func init() { + flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.") + flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')") + flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)") + flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')") + flag.StringVar(&pgPassword, "password", "", "Specifies the user password.") + flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.") + + testJSON = TestJSON{ + Id: 1, + Text: "demo", + } + + testBLOB = make([]byte, 1024) + _, _ = rand.Read(testBLOB) + + testJSONBytes, _ = json.Marshal(testJSON) +} + +// ----------------------------------------------------------------------------- + +func checkSettings(t *testing.T) { + if len(pgHost) == 0 { + t.Fatalf("Server host not specified") + } + if pgPort > 65535 { + t.Fatalf("Server port not specified or invalid") + } + if len(pgUsername) == 0 { + t.Fatalf("User name to access database server not specified") + } + if len(pgPassword) == 0 { + t.Fatalf("User password to access database server not specified") + } + if len(pgDatabaseName) == 0 { + t.Fatalf("Database name not specified") + } +} + +func addressOf[T any](x T) *T { + return &x +} + +func jsonReEncode(src string) (string, error) { + var v interface{} + + err := json.Unmarshal([]byte(src), &v) + if err == nil { + var reencoded []byte + + reencoded, err = json.Marshal(v) + if err == nil { + return string(reencoded), nil + } + } + return "", err +} diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..22a0f8b --- /dev/null +++ b/connection.go @@ -0,0 +1,95 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// ----------------------------------------------------------------------------- + +// Conn encloses a single connection object. +type Conn struct { + db *Database + conn *pgxpool.Conn +} + +// ----------------------------------------------------------------------------- + +// DB returns the underlying database driver. +func (c *Conn) DB() *Database { + return c.db +} + +// Exec executes an SQL statement within the single connection. +func (c *Conn) Exec(ctx context.Context, sql string, args ...interface{}) (int64, error) { + affectedRows := int64(0) + ct, err := c.conn.Exec(ctx, sql, args...) + if err == nil { + affectedRows = ct.RowsAffected() + } + return affectedRows, c.db.processError(err) +} + +// QueryRow executes a SQL query within the single connection. +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + return &rowGetter{ + db: c.db, + row: c.conn.QueryRow(ctx, sql, args...), + } +} + +// QueryRows executes a SQL query within the single connection. +func (c *Conn) QueryRows(ctx context.Context, sql string, args ...interface{}) Rows { + rows, err := c.conn.Query(ctx, sql, args...) + return &rowsGetter{ + db: c.db, + ctx: ctx, + rows: rows, + err: err, + } +} + +// Copy executes a SQL copy query within the single connection. +func (c *Conn) Copy(ctx context.Context, tableName string, columnNames []string, callback CopyCallback) (int64, error) { + n, err := c.conn.CopyFrom( + ctx, + pgx.Identifier{tableName}, + columnNames, + ©WithCallback{ + ctx: ctx, + callback: callback, + }, + ) + + // Done + return n, c.db.processError(err) +} + +// WithinTx executes a callback function within the context of a single connection. +func (c *Conn) WithinTx(ctx context.Context, cb WithinTxCallback) error { + innerTx, err := c.conn.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.ReadCommitted, //pgx.Serializable, + AccessMode: pgx.ReadWrite, + DeferrableMode: pgx.NotDeferrable, + }) + if err == nil { + err = cb(ctx, Tx{ + db: c.db, + tx: innerTx, + }) + if err == nil { + err = innerTx.Commit(ctx) + if err != nil { + err = newError(err, "unable to commit db transaction") + } + } + if err != nil { + _ = innerTx.Rollback(context.Background()) // Using context.Background() on purpose + } + } else { + err = newError(err, "unable to start transaction") + } + return c.db.processError(err) +} diff --git a/helpers.go b/helpers.go index c684fad..d504bea 100644 --- a/helpers.go +++ b/helpers.go @@ -48,3 +48,7 @@ func newError(wrappedErr error, text string) *Error { func encodeDSN(s string) string { return strings.ReplaceAll(s, "'", "\\'") } + +func quoteIdentifier(s string) string { + return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\"" +} diff --git a/internal.go b/internal.go index f8497d5..5d5d9ae 100644 --- a/internal.go +++ b/internal.go @@ -1,7 +1,6 @@ package postgres import ( - "context" "errors" "github.com/jackc/pgx/v5" @@ -13,31 +12,18 @@ var errNoRows = &NoRowsError{} // ----------------------------------------------------------------------------- -// Gets a connection from the pool and initiates a transaction. -func (db *Database) getTx(ctx context.Context) (pgx.Tx, error) { - tx, err := db.pool.BeginTx(ctx, pgx.TxOptions{ - IsoLevel: pgx.ReadCommitted, //pgx.Serializable, - AccessMode: pgx.ReadWrite, - DeferrableMode: pgx.NotDeferrable, - }) - if err != nil { - return nil, newError(err, "unable to start transaction") - } - - //Done - return tx, nil -} - func (db *Database) processError(err error) error { + isNoRows := false if errors.Is(err, pgx.ErrNoRows) { err = errNoRows + isNoRows = true } // Only deal with fatal database errors. Cancellation, timeouts and empty result sets are not considered fatal. db.err.mutex.Lock() defer db.err.mutex.Unlock() - if err != nil && IsDatabaseError(err) && err != errNoRows { + if err != nil && (!isNoRows) && IsDatabaseError(err) { if db.err.last == nil { db.err.last = err if db.err.handler != nil { diff --git a/migration.go b/migration.go new file mode 100644 index 0000000..55cab2e --- /dev/null +++ b/migration.go @@ -0,0 +1,429 @@ +package postgres + +import ( + "context" + "errors" + "hash/fnv" + "math" + "strings" + "unicode/utf8" +) + +// ----------------------------------------------------------------------------- + +// MigrationStep contains details about the SQL sentence to execute in this step. +// Pass an empty struct to indicate the end. +type MigrationStep struct { + // Name is a user defined name for this migration step. I.e.: "v1->v2" + Name string + + // The index of the SQL sentence within a named block. + SequenceNo int + + // Actual SQL sentence to execute in this migration step. + Sql string +} + +// MigrationStepCallback is called to get the migration step details at stepIdx position (starting from 1) +type MigrationStepCallback func(ctx context.Context, stepIdx int) (MigrationStep, error) + +// ----------------------------------------------------------------------------- + +// CreateMigrationStepsFromSqlContent creates an array of migration steps based on the provided content +// +// The expected format is the following: +// # a comment with the step name (starting and ending spaces and dashes will be removed) +// A single SQL sentence +// (extra comment/sql sentence pairs) +func CreateMigrationStepsFromSqlContent(content string) ([]MigrationStep, error) { + steps := make([]MigrationStep, 0) + + currentName := "" + currentSeqNo := 1 + + // Parse content + contentLen := len(content) + for ofs := 0; ofs < contentLen; { + // Check if we can ignore the current line + deltaOfs := shouldIgnoreLine(content[ofs:]) + if deltaOfs > 0 { + ofs += deltaOfs + continue + } + + // Is it a comment at the beginning of the line? + if content[ofs] == '#' { + // Yes, assume new zone if we are not in the middle of an sql sentence + startOfs := ofs + ofs += findEol(content[ofs:]) + + currentName = truncStrBytes(strings.Trim(content[startOfs:ofs], " \t-=#"), 255) + if len(currentName) == 0 { + return nil, errors.New("empty start of block comment") + } + currentSeqNo = 1 + + continue + } + + // At this point we start to parse an SQL sentence + if len(currentName) == 0 { + return nil, errors.New("SQL sentence found outside a block") + } + + currentSql := strings.Builder{} + addSpace := false + for ofs < contentLen { + deltaOfs = skipSpaces(content[ofs:]) + if deltaOfs > 0 { + addSpace = true + ofs += deltaOfs + continue + } + deltaOfs = skipEol(content[ofs:]) + if deltaOfs > 0 { + addSpace = true + ofs += deltaOfs + continue + } + + if content[ofs] == '#' { + // We find a comment, skip until EOL + addSpace = true + ofs += 1 + ofs += findEol(content[ofs:]) + continue + } + + if content[ofs] == ';' { + // Reached the end of the SQL sentence + if currentSql.Len() > 0 { + currentSql.WriteRune(';') + + steps = append(steps, MigrationStep{ + Name: currentName, + SequenceNo: currentSeqNo, + Sql: currentSql.String(), + }) + + // Reset + currentSql = strings.Builder{} + currentSeqNo += 1 + } + + addSpace = false + ofs += 1 + break + } + + if content[ofs] == '\'' { + // Start of a single-quote string + startOfs := ofs + ofs += 1 + + for { + if ofs >= contentLen { + // Open string found + return nil, errors.New("invalid SQL content (open string)") + } + + r, rSize := utf8.DecodeRuneInString(content[ofs:]) + if r == utf8.RuneError || rSize == 0 { + return nil, errors.New("invalid SQL content (invalid char)") + } + ofs += rSize + + // Reached the end of the string or double single-quotes? + if r == '\'' { + if ofs >= contentLen || content[ofs] != '\'' { + break // End of string + } + // Double single-quotes + ofs += 1 + } + } + + if addSpace { + currentSql.WriteRune(' ') + addSpace = false + } + currentSql.WriteString(content[startOfs:ofs]) + continue + } + + if content[ofs] == '"' { + // Start of a double-quotes string + startOfs := ofs + ofs += 1 + + escapedCharacter := false + for { + if ofs >= contentLen { + // Open string found + return nil, errors.New("invalid SQL content (open string)") + } + + r, rSize := utf8.DecodeRuneInString(content[ofs:]) + if r == utf8.RuneError || rSize == 0 { + return nil, errors.New("invalid SQL content (invalid char)") + } + ofs += rSize + + if escapedCharacter { + escapedCharacter = false + continue + } + + // Reached the end of the string? + if r == '"' { + break + } + + // Escaped character? + if r == '\\' { + escapedCharacter = true + } + } + + if addSpace { + currentSql.WriteRune(' ') + addSpace = false + } + currentSql.WriteString(content[startOfs:ofs]) + continue + } + + if content[ofs] == '$' { + // Dollar tag + startOfs := ofs + ofs += 1 + + for { + if ofs >= contentLen { + return nil, errors.New("invalid SQL content (dollar tag)") + } + if content[ofs] == '$' { + ofs += 1 + break + } + if !(content[ofs] == '_' || (content[ofs] >= '0' && content[ofs] <= '9') || + (content[ofs] >= 'A' && content[ofs] <= 'Z') || + (content[ofs] >= 'a' && content[ofs] <= 'z')) { + return nil, errors.New("invalid SQL content (dollar tag)") + } + ofs += 1 + } + tag := content[startOfs:ofs] + + // Find the next tag + deltaOfs = strings.Index(content[ofs:], tag) + if deltaOfs < 0 { + return nil, errors.New("invalid SQL content (open dollar tag)") + } + ofs += deltaOfs + len(tag) + + if addSpace { + currentSql.WriteRune(' ') + addSpace = false + } + currentSql.WriteString(content[startOfs:ofs]) + continue + } + + // If we reached here, it is a single character of an sql sentence + r, rSize := utf8.DecodeRuneInString(content[ofs:]) + if r == utf8.RuneError || rSize == 0 { + return nil, errors.New("invalid SQL content (invalid char)") + } + ofs += rSize + + if addSpace { + currentSql.WriteRune(' ') + addSpace = false + } + currentSql.WriteRune(r) + } + + // At this point we are at the end of the content or reached the end of an SQL sentence + if currentSql.Len() > 0 { + currentSql.WriteRune(';') + + steps = append(steps, MigrationStep{ + Name: currentName, + SequenceNo: currentSeqNo, + Sql: currentSql.String(), + }) + + // Reset + currentSql = strings.Builder{} + currentSeqNo += 1 + } + } + + // Done + return steps, nil +} + +// ----------------------------------------------------------------------------- + +func (db *Database) RunMigrations(ctx context.Context, tableName string, cb MigrationStepCallback) error { + // Lock concurrent access from multiple instances/threads + lockId := db.getMigrationLockId(tableName) + + // Quote table name + tableName = quoteIdentifier(tableName) + + // We must execute migrations within a single connection + return db.WithinConn(ctx, func(ctx context.Context, conn Conn) error { + var stepIdx int32 + + _, err := conn.Exec(ctx, "SELECT pg_advisory_lock($1)", lockId) + if err != nil { + return err + } + defer func() { + _, _ = conn.Exec(ctx, "SELECT pg_advisory_unlock($1)", lockId) + }() + + // Create migration table if it does not exist + _, err = conn.Exec(ctx, + `CREATE TABLE IF NOT EXISTS `+tableName+` ( + id int NOT NULL PRIMARY KEY, + name varchar(255) NOT NULL, + sequence int NOT NULL, + executedAt timestamp NOT NULL + )`) + if err != nil { + return err + } + + // Calculate the next step index to execute based on the last stored + row := conn.QueryRow(ctx, `SELECT id FROM `+tableName+` ORDER BY id DESC LIMIT 1`) + err = row.Scan(&stepIdx) + if err == nil { + stepIdx += 1 + } else { + if !IsNoRowsError(err) { + return err + } + stepIdx = 1 + } + + // Run migrations + for { + var stepInfo MigrationStep + + stepInfo, err = cb(ctx, int(stepIdx)) + if err != nil { + return err + } + // If no name or sql sentence was provided, assume we finished + if len(stepInfo.Name) == 0 { + break + } + + // Execute step + err = conn.WithinTx(ctx, func(ctx context.Context, tx Tx) error { + _, stepErr := tx.Exec(ctx, stepInfo.Sql) + if stepErr == nil { + _, stepErr = tx.Exec( + ctx, + `INSERT INTO `+tableName+` (id, name, sequence, executedAt) VALUES ($1, $2, $3, NOW());`, + stepIdx, stepInfo.Name, stepInfo.SequenceNo, + ) + } + // Done + return stepErr + }) + if err != nil { + return err + } + + // Increment index + stepIdx += 1 + } + + // Done + return nil + }) +} + +func (db *Database) getMigrationLockId(tableName string) int64 { + h := fnv.New64a() + _, _ = h.Write(db.nameHash[:]) + _, _ = h.Write([]byte(tableName)) + return int64(h.Sum64() & math.MaxInt64) +} + +func truncStrBytes(s string, maxBytes int) string { + if len(s) <= maxBytes { + return s + } + truncated := s[:maxBytes] + l := maxBytes + for l > 0 { + // Decode last rune. If it's invalid, we'll move back until we find a valid one. + r, size := utf8.DecodeLastRuneInString(truncated) + if r != utf8.RuneError { + break + } + if size == 0 { + return "" + } + // If the last rune is invalid, trim the string byte by byte until we get a valid rune. + l -= 1 + truncated = truncated[:l] + } + return truncated +} + +// This function returns > 0 if the line must be ignored. Ignored +func shouldIgnoreLine(s string) int { + ofs := skipSpaces(s) + if ofs >= len(s) { + return ofs // Yes + } + // End of line? + if s[ofs] == '\r' || s[ofs] == '\n' { + ofs += skipEol(s[ofs:]) + return ofs // Yes + } + // Comment not at the beginning? + if s[ofs] == '#' && ofs > 0 { + ofs += findEol(s[ofs:]) + ofs += skipEol(s[ofs:]) + return ofs // Yes + } + // Do not skip this line + return 0 +} + +func findEol(s string) int { + eolOfs := strings.IndexByte(s, '\n') + if eolOfs < 0 { + eolOfs = len(s) + } + eolOfs2 := strings.IndexByte(s, '\r') + if eolOfs2 >= 0 && eolOfs2 < eolOfs { + eolOfs = eolOfs2 + } + return eolOfs +} + +func skipSpaces(s string) int { + count := 0 + l := len(s) + for count < l && (s[count] == ' ' || s[count] == '\t') { + count += 1 + } + return count +} + +func skipEol(s string) int { + count := 0 + l := len(s) + for count < l && (s[count] == '\r' || s[count] == '\n') { + count += 1 + } + return count +} diff --git a/migration_test.go b/migration_test.go new file mode 100644 index 0000000..a417ac7 --- /dev/null +++ b/migration_test.go @@ -0,0 +1,155 @@ +package postgres_test + +import ( + "context" + "flag" + "fmt" + "testing" + + "github.com/randlabs/go-postgres" +) + +// ----------------------------------------------------------------------------- + +func TestMigration(t *testing.T) { + var db *postgres.Database + var err error + + // Parse and check command-line parameters + flag.Parse() + checkSettings(t) + + ctx := context.Background() + + // Create database driver + if len(pgUrl) > 0 { + db, err = postgres.NewFromURL(ctx, pgUrl) + } else { + db, err = postgres.New(ctx, postgres.Options{ + Host: pgHost, + Port: uint16(pgPort), + User: pgUsername, + Password: pgPassword, + Name: pgDatabaseName, + }) + } + if err != nil { + t.Fatal(err.Error()) + } + defer db.Close() + + // t.Log("Run migration test") + err = runMigrationTest(ctx, db) + if err != nil { + t.Fatal(err.Error()) + } +} + +func TestMigrationStepParser(t *testing.T) { + steps, err := postgres.CreateMigrationStepsFromSqlContent(` +# Simple table with single quotes in the default values +CREATE TABLE "Employee" ( + "EmployeeID" SERIAL PRIMARY KEY, + "FirstName" VARCHAR(100) NOT NULL, + "LastName" VARCHAR(100) NOT NULL, + "DateOfBirth" DATE NOT NULL DEFAULT '1990-01-01' +); + +# Table with special characters in column names +CREATE TABLE "Order-Details" ( + "Order_ID" INT, + "Product_Name" VARCHAR(255) DEFAULT 'unknown', + "Unit_Price" NUMERIC(10, 2) DEFAULT '0.00', + "Quantity" INT DEFAULT '1', + PRIMARY KEY ("Order_ID", "Product_Name") +); + +# Creating an index on the FirstName and LastName columns +CREATE INDEX "idx_employee_name" ON "Employee" ("FirstName", "LastName"); + +# Creating a unique index using a function +CREATE UNIQUE INDEX "idx_lower_last_name" ON "Employee" (LOWER("LastName")); + +# A function to calculate age from the DateOfBirth +CREATE FUNCTION "calculate_age" ("dob" DATE) RETURNS INT AS $$ +BEGIN + RETURN DATE_PART('year', AGE("dob")); +END; +$$ LANGUAGE plpgsql; + +# A function to concatenate first and last name with a space between +CREATE FUNCTION "full_name" ("first" TEXT, "last" TEXT) RETURNS TEXT AS $tag$ +BEGIN + RETURN "first" || ' ' || "last"; +END; +$tag$ LANGUAGE plpgsql; +`) + if err != nil { + t.Fatal(err.Error()) + } + if len(steps) != 6 { + t.Fatalf("Wrong number of steps: %d", len(steps)) + } +} + +// ----------------------------------------------------------------------------- + +func runMigrationTest(ctx context.Context, db *postgres.Database) error { + var stepIdx int + + // Destroy old test tables if exists + _, err := db.Exec(ctx, `DROP TABLE IF EXISTS migrations_test`) + if err == nil { + _, err = db.Exec(ctx, `DROP TABLE IF EXISTS migrations`) + } + if err != nil { + return fmt.Errorf("unable to drop tables [err=%v]", err.Error()) + } + + // Run migrations + err = db.RunMigrations(ctx, "migrations", func(ctx context.Context, stepIdx int) (postgres.MigrationStep, error) { + switch stepIdx { + case 1: + return postgres.MigrationStep{ + Name: "v1", + SequenceNo: 1, + Sql: `CREATE TABLE migrations_test (id int NOT NULL PRIMARY KEY, name varchar(255) NOT NULL);`, + }, nil + + case 2: + return postgres.MigrationStep{ + Name: "v1", + SequenceNo: 2, + Sql: `ALTER TABLE migrations_test ADD COLUMN description TEXT;`, + }, nil + } + return postgres.MigrationStep{}, nil + }) + if err != nil { + return fmt.Errorf("unable to run migrations [err=%v]", err.Error()) + } + + // Check last step + row := db.QueryRow(ctx, `SELECT id FROM migrations ORDER BY id DESC LIMIT 1`) + err = row.Scan(&stepIdx) + if err != nil { + return fmt.Errorf("unable to get last migration step [err=%v]", err.Error()) + } + if stepIdx != 2 { + return fmt.Errorf("last migration step mismatch [got=%v] [expected=2]", stepIdx) + } + + // Run more migrations + err = db.RunMigrations(ctx, "migrations", func(ctx context.Context, stepIdx int) (postgres.MigrationStep, error) { + if stepIdx != 3 { + return postgres.MigrationStep{}, fmt.Errorf("migration step mismatch [got=%v] [expected=3]", stepIdx) + } + return postgres.MigrationStep{}, nil + }) + if err != nil { + return fmt.Errorf("unable to run more migrations [err=%v]", err.Error()) + } + + // Done + return nil +} diff --git a/postgres.go b/postgres.go index 30f7eb4..e4b5b77 100644 --- a/postgres.go +++ b/postgres.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "crypto/sha256" "errors" "fmt" "net/url" @@ -16,9 +17,18 @@ import ( // ----------------------------------------------------------------------------- +const ( + defaultPoolMaxConns = 32 +) + +// ----------------------------------------------------------------------------- + // WithinTxCallback defines a callback called in the context of the initiated transaction. type WithinTxCallback = func(ctx context.Context, tx Tx) error +// WithinConnCallback defines a callback called in the context of a single connection. +type WithinConnCallback = func(ctx context.Context, conn Conn) error + // CopyCallback defines a callback that is called for each record being copied to the database type CopyCallback func(ctx context.Context, idx int) ([]interface{}, error) @@ -32,17 +42,19 @@ type Database struct { handler ErrorHandler last error } + nameHash [32]byte } // Options defines the database connection options. type Options struct { - Host string `json:"host"` - Port uint16 `json:"port"` - User string `json:"user"` - Password string `json:"password"` - Name string `json:"name"` - MaxConns int32 `json:"maxConns"` - SSLMode SSLMode + Host string `json:"host"` + Port uint16 `json:"port"` + User string `json:"user"` + Password string `json:"password"` + Name string `json:"name"` + MaxConns int32 `json:"maxConns"` + SSLMode SSLMode + ExtendedSettings map[string]string `json:"extendedSettings"` } // ErrorHandler defines a custom error handler. @@ -86,14 +98,27 @@ func New(ctx context.Context, opts Options) (*Database, error) { db := Database{} db.err.mutex = sync.Mutex{} - connString := fmt.Sprintf( + // Create a hash of the database name + h := sha256.New() + _, _ = h.Write([]byte(opts.Name)) + copy(db.nameHash[:], h.Sum(nil)) + + // Create PGX pool configuration. Usage of ParseConfig is mandatory :( + sbConnString := strings.Builder{} + _, _ = sbConnString.WriteString(fmt.Sprintf( "host='%s' port=%d user='%s' password='%s' dbname='%s' sslmode=%s", encodeDSN(opts.Host), opts.Port, encodeDSN(opts.User), encodeDSN(opts.Password), encodeDSN(opts.Name), sslMode, - ) - - // Create PGX pool configuration. Usage of ParseConfig is mandatory :( - poolConfig, err := pgxpool.ParseConfig(connString) + )) + if opts.ExtendedSettings != nil { + for k, v := range opts.ExtendedSettings { + _, _ = sbConnString.WriteRune(' ') + _, _ = sbConnString.WriteString(k) + _, _ = sbConnString.WriteRune('=') + _, _ = sbConnString.WriteString(encodeDSN(v)) + } + } + poolConfig, err := pgxpool.ParseConfig(sbConnString.String()) if err != nil { db.Close() return nil, errors.New("unable to parse connection string") @@ -102,12 +127,13 @@ func New(ctx context.Context, opts Options) (*Database, error) { // Override some settings poolConfig.MaxConns = opts.MaxConns if opts.MaxConns <= 0 { - poolConfig.MaxConns = 32 + poolConfig.MaxConns = defaultPoolMaxConns } poolConfig.MaxConnIdleTime = 10 * time.Minute poolConfig.HealthCheckPeriod = time.Minute poolConfig.MaxConnLifetime = 1 * time.Hour poolConfig.MaxConnLifetimeJitter = time.Minute + poolConfig.ConnConfig.ConnectTimeout = 15 * time.Second // Create the database connection pool db.pool, err = pgxpool.NewWithConfig(ctx, poolConfig) @@ -120,9 +146,19 @@ func New(ctx context.Context, opts Options) (*Database, error) { return &db, nil } +// IsPostgresURL returns true if the url schema is postgres +func IsPostgresURL(rawUrl string) bool { + return strings.HasPrefix(rawUrl, "pg://") || + strings.HasPrefix(rawUrl, "postgres://") || + strings.HasPrefix(rawUrl, "postgresql://") +} + // NewFromURL creates a new postgresql database driver from an URL func NewFromURL(ctx context.Context, rawUrl string) (*Database, error) { - opts := Options{} + opts := Options{ + SSLMode: SSLModeAllow, + MaxConns: defaultPoolMaxConns, + } u, err := url.ParseRequestURI(rawUrl) if err != nil { @@ -139,11 +175,11 @@ func NewFromURL(ctx context.Context, rawUrl string) (*Database, error) { if len(opts.Host) == 0 { return nil, errors.New("invalid host") } - s := u.Port() - if len(s) == 0 { + port := u.Port() + if len(port) == 0 { opts.Port = 5432 } else { - val, err2 := strconv.Atoi(s) + val, err2 := strconv.Atoi(port) if err2 != nil || val < 1 || val > 65535 { return nil, errors.New("invalid port") } @@ -160,36 +196,49 @@ func NewFromURL(ctx context.Context, rawUrl string) (*Database, error) { } // Check database name - if len(u.Path) < 1 || (!strings.HasPrefix(u.Path, "/")) || strings.Index(u.Path[1:], "/") >= 0 { + if len(u.Path) < 2 || (!strings.HasPrefix(u.Path, "/")) || strings.Index(u.Path[1:], "/") >= 0 { return nil, errors.New("invalid database name") } opts.Name = u.Path[1:] - // Check ssl mode - opts.SSLMode = SSLModeDisable - switch u.Query().Get("sslmode") { - case "allow": - opts.SSLMode = SSLModeAllow - - case "required": - opts.SSLMode = SSLModeRequired + // Parse query parameters + for k, values := range u.Query() { + v := "" + if len(values) > 0 { + v = values[0] + } + if k == "sslmode" { + // Check ssl mode + switch v { + case "allow": + opts.SSLMode = SSLModeAllow - case "disabled": - fallthrough - case "": + case "required": + opts.SSLMode = SSLModeRequired - default: - return nil, errors.New("invalid SSL mode") - } + case "disabled": + fallthrough + case "": - // Check max connections count - s = u.Query().Get("maxconn") - if len(s) > 0 { - val, err2 := strconv.Atoi(s) - if err2 != nil || val < 0 { - return nil, errors.New("invalid max connections count") + default: + return nil, errors.New("invalid SSL mode") + } + } else if k == "maxconns" { + // Check max connections count + if len(v) > 0 { + val, err2 := strconv.Atoi(v) + if err2 != nil || val < 0 { + return nil, errors.New("invalid max connections count") + } + opts.MaxConns = int32(val) + } + } else { + // Extended setting + if opts.ExtendedSettings == nil { + opts.ExtendedSettings = make(map[string]string) + } + opts.ExtendedSettings[k] = v } - opts.MaxConns = int32(val) } // Create @@ -274,10 +323,14 @@ func (db *Database) Copy(ctx context.Context, tableName string, columnNames []st } // WithinTx executes a callback function within the context of a transaction -func (db *Database) WithinTx(ctx context.Context, callback WithinTxCallback) error { - tx, err := db.getTx(ctx) +func (db *Database) WithinTx(ctx context.Context, cb WithinTxCallback) error { + tx, err := db.pool.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.ReadCommitted, //pgx.Serializable, + AccessMode: pgx.ReadWrite, + DeferrableMode: pgx.NotDeferrable, + }) if err == nil { - err = callback(ctx, Tx{ + err = cb(ctx, Tx{ db: db, tx: tx, }) @@ -290,6 +343,21 @@ func (db *Database) WithinTx(ctx context.Context, callback WithinTxCallback) err if err != nil { _ = tx.Rollback(context.Background()) // Using context.Background() on purpose } + } else { + err = newError(err, "unable to start transaction") + } + return db.processError(err) +} + +// WithinConn executes a callback function within the context of a single connection +func (db *Database) WithinConn(ctx context.Context, cb WithinConnCallback) error { + conn, err := db.pool.Acquire(ctx) + if err == nil { + err = cb(ctx, Conn{ + db: db, + conn: conn, + }) + conn.Release() } return db.processError(err) } diff --git a/postgres_test.go b/postgres_test.go index e8515e3..01f0e29 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -2,8 +2,6 @@ package postgres_test import ( "context" - "crypto/rand" - "encoding/json" "errors" "flag" "fmt" @@ -65,42 +63,6 @@ type TestJSON struct { Text string `json:"text"` } -var ( - pgUrl string - pgHost string - pgPort uint - pgUsername string - pgPassword string - pgDatabaseName string -) - -var ( - testJSON TestJSON - testBLOB []byte - testJSONBytes []byte -) - -// ----------------------------------------------------------------------------- - -func init() { - flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.") - flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')") - flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)") - flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')") - flag.StringVar(&pgPassword, "password", "", "Specifies the user password.") - flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.") - - testJSON = TestJSON{ - Id: 1, - Text: "demo", - } - - testBLOB = make([]byte, 1024) - _, _ = rand.Read(testBLOB) - - testJSONBytes, _ = json.Marshal(testJSON) -} - // ----------------------------------------------------------------------------- func TestPostgres(t *testing.T) { @@ -157,29 +119,11 @@ func TestPostgres(t *testing.T) { // ----------------------------------------------------------------------------- -func checkSettings(t *testing.T) { - if len(pgHost) == 0 { - t.Fatalf("Server host not specified") - } - if pgPort > 65535 { - t.Fatalf("Server port not specified or invalid") - } - if len(pgUsername) == 0 { - t.Fatalf("User name to access database server not specified") - } - if len(pgPassword) == 0 { - t.Fatalf("User password to access database server not specified") - } - if len(pgDatabaseName) == 0 { - t.Fatalf("Database name not specified") - } -} - func createTestTable(ctx context.Context, db *postgres.Database) error { // Destroy old test table if exists _, err := db.Exec(ctx, `DROP TABLE IF EXISTS go_postgres_test_table CASCADE`) if err != nil { - return fmt.Errorf("Unable to drop tables [err=%v]", err.Error()) + return fmt.Errorf("unable to drop tables [err=%v]", err.Error()) } // Create the test table @@ -203,7 +147,7 @@ func createTestTable(ctx context.Context, db *postgres.Database) error { PRIMARY KEY (id) )`) if err != nil { - return fmt.Errorf("Unable to create test table [err=%v]", err.Error()) + return fmt.Errorf("unable to create test table [err=%v]", err.Error()) } // Done @@ -216,13 +160,13 @@ func insertTestData(ctx context.Context, db *postgres.Database) error { rd := genTestRowDef(idx, true) err := insertTestRowDef(ctx, tx, rd) if err != nil { - return fmt.Errorf("Unable to insert test data [id=%v/err=%v]", rd.id, err.Error()) + return fmt.Errorf("unable to insert test data [id=%v/err=%v]", rd.id, err.Error()) } nrd := genTestNullableRowDef(idx, true) err = insertTestNullableRowDef(ctx, tx, nrd) if err != nil { - return fmt.Errorf("Unable to insert test data [id=%v/err=%v]", nrd.id, err.Error()) + return fmt.Errorf("unable to insert test data [id=%v/err=%v]", nrd.id, err.Error()) } } // Done @@ -235,7 +179,7 @@ func readTestData(ctx context.Context, db *postgres.Database) error { compareRd := genTestRowDef(idx, false) rd, err := readTestRowDef(ctx, db, compareRd.id) if err != nil { - return fmt.Errorf("Unable to verify test data [id=%v/err=%v]", compareRd.id, err.Error()) + return fmt.Errorf("unable to verify test data [id=%v/err=%v]", compareRd.id, err.Error()) } // Do deep comparison if !reflect.DeepEqual(compareRd, rd) { @@ -245,12 +189,12 @@ func readTestData(ctx context.Context, db *postgres.Database) error { compareNrd := genTestNullableRowDef(idx, false) nrd, err := readTestNullableRowDef(ctx, db, compareNrd.id) if err != nil { - return fmt.Errorf("Unable to verify test data [id=%v/err=%v]", compareNrd.id, err.Error()) + return fmt.Errorf("unable to verify test data [id=%v/err=%v]", compareNrd.id, err.Error()) } // Do deep comparison if !reflect.DeepEqual(compareNrd, nrd) { - return fmt.Errorf("Data mismatch while comparing test data [id=%v]", compareNrd.id) + return fmt.Errorf("data mismatch while comparing test data [id=%v]", compareNrd.id) } } @@ -265,17 +209,17 @@ func readMultiTestData(ctx context.Context, db *postgres.Database) error { } rd, err := readMultiTestRowDef(ctx, db, compareRd) if err != nil { - return fmt.Errorf("Unable to verify test data [err=%v]", err.Error()) + return fmt.Errorf("unable to verify test data [err=%v]", err.Error()) } // Do deep comparison if len(compareRd) != len(rd) { - return fmt.Errorf("Data mismatch while comparing test data [len1=%d/len2=%d]", len(compareRd), len(rd)) + return fmt.Errorf("data mismatch while comparing test data [len1=%d/len2=%d]", len(compareRd), len(rd)) } for idx := 0; idx < len(rd); idx++ { if !reflect.DeepEqual(compareRd[idx], rd[idx]) { - return fmt.Errorf("Data mismatch while comparing test data [id=%v]", compareRd[idx].id) + return fmt.Errorf("data mismatch while comparing test data [id=%v]", compareRd[idx].id) } } @@ -516,22 +460,3 @@ func readTestNullableRowDef(ctx context.Context, db *postgres.Database, id int) // Done return nrd, nil } - -func addressOf[T any](x T) *T { - return &x -} - -func jsonReEncode(src string) (string, error) { - var v interface{} - - err := json.Unmarshal([]byte(src), &v) - if err == nil { - var reencoded []byte - - reencoded, err = json.Marshal(v) - if err == nil { - return string(reencoded), nil - } - } - return "", err -} diff --git a/qodana.yaml b/qodana.yaml new file mode 100644 index 0000000..215d808 --- /dev/null +++ b/qodana.yaml @@ -0,0 +1,29 @@ +#-------------------------------------------------------------------------------# +# Qodana analysis is configured by qodana.yaml file # +# https://www.jetbrains.com/help/qodana/qodana-yaml.html # +#-------------------------------------------------------------------------------# +version: "1.0" + +#Specify inspection profile for code analysis +profile: + name: qodana.starter + +#Enable inspections +#include: +# - name: + +#Disable inspections +#exclude: +# - name: +# paths: +# - + +#Execute shell command before Qodana execution (Applied in CI/CD pipeline) +#bootstrap: sh ./prepare-qodana.sh + +#Install IDE plugins before Qodana execution (Applied in CI/CD pipeline) +#plugins: +# - id: #(plugin id can be found at https://plugins.jetbrains.com) + +#Specify Qodana linter for analysis (Applied in CI/CD pipeline) +linter: jetbrains/qodana-go:latest diff --git a/transaction.go b/transaction.go index 1493daf..980fb1b 100644 --- a/transaction.go +++ b/transaction.go @@ -8,7 +8,7 @@ import ( // ----------------------------------------------------------------------------- -// Tx encloses a transation object. +// Tx encloses a transaction object. type Tx struct { db *Database tx pgx.Tx @@ -23,8 +23,12 @@ func (tx *Tx) DB() *Database { // Exec executes an SQL statement within the transaction. func (tx *Tx) Exec(ctx context.Context, sql string, args ...interface{}) (int64, error) { + affectedRows := int64(0) ct, err := tx.tx.Exec(ctx, sql, args...) - return ct.RowsAffected(), tx.db.processError(err) + if err == nil { + affectedRows = ct.RowsAffected() + } + return affectedRows, tx.db.processError(err) } // QueryRow executes a SQL query within the transaction. @@ -61,3 +65,26 @@ func (tx *Tx) Copy(ctx context.Context, tableName string, columnNames []string, // Done return n, tx.db.processError(err) } + +// WithinTx executes a callback function within the context of a nested transaction. +func (tx *Tx) WithinTx(ctx context.Context, cb WithinTxCallback) error { + innerTx, err := tx.tx.Begin(ctx) + if err == nil { + err = cb(ctx, Tx{ + db: tx.db, + tx: innerTx, + }) + if err == nil { + err = innerTx.Commit(ctx) + if err != nil { + err = newError(err, "unable to commit db transaction") + } + } + if err != nil { + _ = innerTx.Rollback(context.Background()) // Using context.Background() on purpose + } + } else { + err = newError(err, "unable to start transaction") + } + return tx.db.processError(err) +}