Skip to content

Commit

Permalink
Merge pull request #6 from randlabs/updates
Browse files Browse the repository at this point in the history
Updated dependencies and added conn URL parser
  • Loading branch information
mxmauro authored Mar 7, 2024
2 parents cf82d3c + edcdced commit fa37e31
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ func main() {

## LICENSE

See `LICENSE` file for details.
See [LICENSE](/LICENSE) file for details.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ module github.com/randlabs/go-postgres

go 1.19

require github.com/jackc/pgx/v5 v5.5.1
require github.com/jackc/pgx/v5 v5.5.4

require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/text v0.14.0 // indirect
)
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand All @@ -14,8 +14,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
Expand Down
102 changes: 94 additions & 8 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -58,16 +61,19 @@ const (

// New creates a new postgresql database driver.
func New(ctx context.Context, opts Options) (*Database, error) {
var sslMode string

// Create database object
db := Database{}
db.err.mutex = sync.Mutex{}

// Setup basic configuration options
// Validate options
if len(opts.Host) == 0 {
return nil, errors.New("invalid host")
}
if len(opts.User) == 0 {
return nil, errors.New("invalid user name")
}
if len(opts.Name) == 0 {
return nil, errors.New("invalid database name")
}
sslMode := "disable"
switch opts.SSLMode {
case SSLModeDisable:
sslMode = "disable"
case SSLModeAllow:
sslMode = "prefer"
case SSLModeRequired:
Expand All @@ -76,6 +82,10 @@ func New(ctx context.Context, opts Options) (*Database, error) {
return nil, errors.New("invalid SSL mode")
}

// Create database object
db := Database{}
db.err.mutex = sync.Mutex{}

connString := fmt.Sprintf(
"host='%s' port=%d user='%s' password='%s' dbname='%s' sslmode=%s",
encodeDSN(opts.Host), opts.Port, encodeDSN(opts.User), encodeDSN(opts.Password), encodeDSN(opts.Name),
Expand Down Expand Up @@ -110,6 +120,82 @@ func New(ctx context.Context, opts Options) (*Database, error) {
return &db, nil
}

// NewFromURL creates a new postgresql database driver from an URL
func NewFromURL(ctx context.Context, rawUrl string) (*Database, error) {
opts := Options{}

u, err := url.ParseRequestURI(rawUrl)
if err != nil {
return nil, errors.New("invalid url provided")
}

// Check schema
if u.Scheme != "pg" && u.Scheme != "postgres" && u.Scheme != "postgresql" {
return nil, errors.New("invalid url schema")
}

// Check host name and port
opts.Host = u.Hostname()
if len(opts.Host) == 0 {
return nil, errors.New("invalid host")
}
s := u.Port()
if len(s) == 0 {
opts.Port = 5432
} else {
val, err2 := strconv.Atoi(s)
if err2 != nil || val < 1 || val > 65535 {
return nil, errors.New("invalid port")
}
opts.Port = uint16(val)
}

// Check user and password
if u.User == nil {
return nil, errors.New("invalid user name")
}
opts.User = u.User.Username()
if len(opts.User) == 0 {
return nil, errors.New("invalid user name")
}

// Check database name
if len(u.Path) < 1 || (!strings.HasPrefix(u.Path, "/")) || strings.Index(u.Path[1:], "/") >= 0 {
return nil, errors.New("invalid database name")
}
opts.Name = u.Path[1:]

// Check ssl mode
opts.SSLMode = SSLModeDisable
switch u.Query().Get("sslmode") {
case "allow":
opts.SSLMode = SSLModeAllow

case "required":
opts.SSLMode = SSLModeRequired

case "disabled":
fallthrough
case "":

default:
return nil, errors.New("invalid SSL mode")
}

// Check max connections count
s = u.Query().Get("maxconn")
if len(s) > 0 {
val, err2 := strconv.Atoi(s)
if err2 != nil || val < 0 {
return nil, errors.New("invalid max connections count")
}
opts.MaxConns = int32(val)
}

// Create
return New(ctx, opts)
}

// Close shutdown the connection pool
func (db *Database) Close() {
if db.pool != nil {
Expand Down
39 changes: 20 additions & 19 deletions postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type TestJSON struct {
}

var (
pgUrl string
pgHost string
pgPort uint
pgUsername string
Expand All @@ -82,10 +83,11 @@ var (
// -----------------------------------------------------------------------------

func init() {
flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.")
flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')")
flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)")
flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')")
flag.StringVar(&pgPassword, "password", "", "Specifies the user passwonrd.")
flag.StringVar(&pgPassword, "password", "", "Specifies the user password.")
flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.")

testJSON = TestJSON{
Expand All @@ -102,56 +104,55 @@ func init() {
// -----------------------------------------------------------------------------

func TestPostgres(t *testing.T) {
var db *postgres.Database
var err error

// Parse and check command-line parameters
flag.Parse()
checkSettings(t)

ctx := context.Background()

// Create database driver
db, err := postgres.New(context.Background(), postgres.Options{
Host: pgHost,
Port: uint16(pgPort),
User: pgUsername,
Password: pgPassword,
Name: pgDatabaseName,
})
if len(pgUrl) > 0 {
db, err = postgres.NewFromURL(ctx, pgUrl)
} else {
db, err = postgres.New(ctx, postgres.Options{
Host: pgHost,
Port: uint16(pgPort),
User: pgUsername,
Password: pgPassword,
Name: pgDatabaseName,
})
}
if err != nil {
t.Fatalf("%v", err.Error())
}
// We comment the next defer line because we want to do a clean database pool shutdown on errors and
// calling fatal exits the process.
// defer db.Close()

ctx := context.Background()
defer db.Close()

t.Log("Creating test table")
err = createTestTable(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

t.Log("Inserting test data")
err = insertTestData(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

t.Log("Reading test data")
err = readTestData(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

t.Log("Reading test data (multi-row)")
err = readMultiTestData(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

db.Close()
}

// -----------------------------------------------------------------------------
Expand Down

0 comments on commit fa37e31

Please sign in to comment.