diff --git a/cmd/worklog/pgstore/db.go b/cmd/worklog/pgstore/db.go new file mode 100644 index 0000000..c6dd1c5 --- /dev/null +++ b/cmd/worklog/pgstore/db.go @@ -0,0 +1,830 @@ +// Copyright ©2024 Dan Kortschak. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pgstore provides a worklog data storage layer using PostgreSQL. +package pgstore + +import ( + "compress/gzip" + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/url" + "os" + "path" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "golang.org/x/sys/execabs" + + worklog "github.com/kortschak/dex/cmd/worklog/api" +) + +// DB is a persistent store. +type DB struct { + name string + host string + store *pgx.Conn + roStore *pgx.Conn +} + +type execer interface { + Exec(ctx context.Context, query string, args ...any) (pgconn.CommandTag, error) +} + +type result struct { + pgconn.CommandTag +} + +func (r result) RowsAffected() (int64, error) { + return r.CommandTag.RowsAffected(), nil +} + +func (r result) LastInsertId() (int64, error) { + return 0, nil +} + +type querier interface { + Query(ctx context.Context, query string, args ...any) (pgx.Rows, error) +} + +func txDone(ctx context.Context, tx pgx.Tx, err *error) { + if *err == nil { + *err = tx.Commit(ctx) + } else { + *err = errors.Join(*err, tx.Rollback(ctx)) + } +} + +// Open opens a PostgresSQL DB. See [pgx.Connect] for name handling details. +// Two connections to the database are created, one using the username within +// the name parameter and one with the same username, but "_ro" appended. If +// the second user does not have SELECT role_table_grants for the buckets' and +// 'events' the database, or has any non-SELECT role_table_grants, the second +// connection will be closed and a [Warning] error will be returned. 
If the +// second connection is closed, the [DB.Select] method will return a non-nil +// error and perform no DB operation. Open attempts to get the CNAME for the +// host, which may wait indefinitely, so a timeout context can be provided to +// fall back to the kernel-provided hostname. +func Open(ctx context.Context, name, host string) (*DB, error) { + u, err := url.Parse(name) + if err != nil { + return nil, err + } + if host == "" { + host, err = hostname(ctx) + if err != nil { + return nil, err + } + } + + db, err := pgx.Connect(ctx, name) + if err != nil { + return nil, err + } + _, err = db.Exec(ctx, Schema) + if err != nil { + return nil, errors.Join(err, db.Close(ctx)) + } + + var ( + dbRO *pgx.Conn + warn error + ) + if u.User != nil { + roUser := u.User.Username() + "_ro" + pass, ok := u.User.Password() + if ok { + u.User = url.UserPassword(roUser, pass) + } else { + u.User = url.User(roUser) + } + const ( + nonSelects = `select + count(*) = 0 + from + information_schema.role_table_grants + where + not privilege_type ilike 'SELECT' + and grantee = $1` + selects = `select + count(distinct table_name) = 2 + from + information_schema.role_table_grants + where + privilege_type ilike 'SELECT' and table_name in ('buckets', 'events') + and grantee = $1` + ) + for _, check := range []string{ + nonSelects, + selects, + } { + var ok bool + err = db.QueryRow(ctx, check, roUser).Scan(&ok) + if err != nil { + db.Close(ctx) + return nil, err + } + if !ok { + warn = Warning{errors.New("ro user failed capability checks")} + break + } + } + if warn == nil { + dbRO, warn = pgx.Connect(ctx, u.String()) + if warn != nil { + warn = Warning{warn} + dbRO = nil + } + } + u.User = nil + } + + return &DB{name: u.String(), host: host, store: db, roStore: dbRO}, warn +} + +// Warning is a warning-only error. +type Warning struct { + error +} + +// hostname returns the FQDN of the local host, falling back to the hostname +// reported by the kernel if CNAME lookup fails. 
+func hostname(ctx context.Context) (string, error) { + host, err := os.Hostname() + if err != nil { + return "", err + } + cname, err := net.DefaultResolver.LookupCNAME(ctx, host) + if err != nil { + return host, nil + } + return strings.TrimSuffix(cname, "."), nil +} + +// Name returns the name of the database as provided to Open. +func (db *DB) Name() string { + if db == nil { + return "" + } + return db.name +} + +// Backup creates a backup of the DB using the pg_dump command into the provided +// directory as a gzip file. It returns the path of the backup. +func (db *DB) Backup(ctx context.Context, dir string) (string, error) { + u, err := url.Parse(db.name) + if err != nil { + return "", err + } + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + return "", err + } + dbname := path.Base(u.Path) + + dst := filepath.Join(dir, dbname+"_"+time.Now().In(time.UTC).Format("20060102150405")+".gz") + cmd := execabs.Command("pg_dump", "-h", host, "-p", port, dbname) + f, err := os.OpenFile(dst, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) + if err != nil { + return "", err + } + w := gzip.NewWriter(f) + cmd.Stdout = w + err = errors.Join(cmd.Run(), w.Close(), f.Sync(), f.Close()) + + return dst, err +} + +// Close closes the database. +func (db *DB) Close(ctx context.Context) error { + var roErr error + if db.roStore != nil { + roErr = db.roStore.Close(ctx) + } + return errors.Join(db.store.Close(ctx), roErr) +} + +// Schema is the DB schema. 
+const Schema = ` +create table if not exists buckets ( + rowid SERIAL PRIMARY KEY, + id TEXT UNIQUE NOT NULL, + name TEXT NOT NULL, + type TEXT NOT NULL, + client TEXT NOT NULL, + hostname TEXT NOT NULL, + created TIMESTAMP WITH TIME ZONE NOT NULL, + timezone TEXT NOT NULL, -- tz of created, not mutated after first write + datastr JSONB NOT NULL +); +create table if not exists events ( + id SERIAL PRIMARY KEY, + bucketrow INTEGER NOT NULL, + starttime TIMESTAMP WITH TIME ZONE NOT NULL, + endtime TIMESTAMP WITH TIME ZONE NOT NULL, + timezone TEXT NOT NULL, -- tz of starttime, not mutated after first write + datastr JSONB NOT NULL, + FOREIGN KEY (bucketrow) REFERENCES buckets(rowid) +); +create index if not exists event_index_id ON events(id); +create index if not exists event_index_starttime ON events(bucketrow, starttime); +create index if not exists event_index_endtime ON events(bucketrow, endtime); +` + +const tzFormat = "-07:00" + +// BucketID returns the internal bucket ID for the provided bucket uid. +func (db *DB) BucketID(uid string) string { + return fmt.Sprintf("%s_%s", uid, db.host) +} + +const CreateBucket = `insert into buckets(id, name, type, client, hostname, created, timezone, datastr) values ($1, $2, $3, $4, $5, $6, $7, $8) +on conflict (id) do nothing;` + +// CreateBucket creates a new entry in the bucket table. If the entry already +// exists it will return an sqlite.Error with the code sqlite3.SQLITE_CONSTRAINT_UNIQUE. +// The SQL command run is [CreateBucket]. 
+func (db *DB) CreateBucket(ctx context.Context, uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) { + bid := db.BucketID(uid) + tx, err := db.store.Begin(ctx) + if err != nil { + return nil, err + } + defer txDone(ctx, tx, &err) + return createBucket(ctx, tx, bid, name, typ, client, db.host, created, data) +} + +func createBucket(ctx context.Context, tx pgx.Tx, bid, name, typ, client, host string, created time.Time, data map[string]any) (*worklog.BucketMetadata, error) { + if data == nil { + // datastr has a NOT NULL constraint. + data = make(map[string]any) + } + _, err := tx.Exec(ctx, CreateBucket, bid, name, typ, client, host, created, created.Format(tzFormat), data) + if err != nil { + return nil, err + } + m, err := bucketMetadata(ctx, tx, bid) + if err != nil { + return nil, err + } + return m, nil +} + +const BucketMetadata = `select id, name, type, client, hostname, created, timezone, datastr from buckets where id = $1` + +// BucketMetadata returns the metadata for the bucket with the provided internal +// bucket ID. +// The SQL command run is [BucketMetadata]. 
+func (db *DB) BucketMetadata(ctx context.Context, bid string) (*worklog.BucketMetadata, error) { + return bucketMetadata(ctx, db.store, bid) +} + +func bucketMetadata(ctx context.Context, db querier, bid string) (*worklog.BucketMetadata, error) { + rows, err := db.Query(ctx, BucketMetadata, bid) + if err != nil { + return nil, err + } + defer rows.Close() + if !rows.Next() { + return nil, io.EOF + } + var ( + m worklog.BucketMetadata + tz string + ) + err = rows.Scan(&m.ID, &m.Name, &m.Type, &m.Client, &m.Hostname, &m.Created, &tz, &m.Data) + if err != nil { + return nil, err + } + timezone, err := time.Parse(tzFormat, tz) + if err != nil { + return &m, fmt.Errorf("invalid timezone for %s bucket %s: %w", m.Name, m.ID, err) + } + m.Created = m.Created.In(timezone.Location()) + if rows.Next() { + return &m, errors.New("unexpected item") + } + return &m, nil +} + +const InsertEvent = `insert into events(bucketrow, starttime, endtime, timezone, datastr) values ((select rowid from buckets where id = $1), $2, $3, $4, $5)` + +// InsertEvent inserts a new event into the events table. +// The SQL command run is [InsertEvent]. +func (db *DB) InsertEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) { + bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) + return insertEvent(ctx, db.store, bid, e) +} + +func insertEvent(ctx context.Context, db execer, bid string, e *worklog.Event) (sql.Result, error) { + res, err := db.Exec(ctx, InsertEvent, bid, e.Start, e.End, e.Start.Format(tzFormat), e.Data) + return result{res}, err +} + +const UpdateEvent = `update events set starttime = $1, endtime = $2, datastr = $3 where id = $4 and bucketrow = ( + select rowid from buckets where id = $5 +)` + +// UpdateEvent updates the event in the store corresponding to the provided +// event. +// The SQL command run is [UpdateEvent]. 
+func (db *DB) UpdateEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) { + bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) + res, err := db.store.Exec(ctx, UpdateEvent, e.Start, e.End, e.Data, e.ID, bid) + return result{res}, err +} + +const LastEvent = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and endtime = ( + select max(endtime) from events where bucketrow = ( + select rowid from buckets where id = $1 + ) limit 1 +) limit 1` + +// LastEvent returns the last event in the named bucket. +// The SQL command run is [LastEvent]. +func (db *DB) LastEvent(ctx context.Context, uid string) (*worklog.Event, error) { + bid := db.BucketID(uid) + rows, err := db.store.Query(ctx, LastEvent, bid) + if err != nil { + return nil, err + } + defer rows.Close() + if !rows.Next() { + return nil, io.EOF + } + var ( + e worklog.Event + tz string + ) + err = rows.Scan(&e.ID, &e.Start, &e.End, &tz, &e.Data) + if err != nil { + return nil, err + } + timezone, err := time.Parse(tzFormat, tz) + if err != nil { + return &e, fmt.Errorf("invalid timezone for event %d: %w", e.ID, err) + } + loc := timezone.Location() + e.Start = e.Start.In(loc) + e.End = e.End.In(loc) + if rows.Next() { + return &e, errors.New("unexpected item") + } + return &e, nil +} + +// Dump dumps the complete database into a slice of [worklog.BucketMetadata]. 
+func (db *DB) Dump(ctx context.Context) ([]worklog.BucketMetadata, error) { + m, err := db.buckets(ctx) + if err != nil { + return nil, err + } + for i, b := range m { + bucket, ok := strings.CutSuffix(b.ID, "_"+b.Hostname) + if !ok { + return m, fmt.Errorf("invalid bucket ID at %d: %s", i, b.ID) + } + e, err := db.events(ctx, b.ID) + if err != nil { + return m, err + } + for j := range e { + e[j].Bucket = bucket + } + m[i].Events = e + } + return m, nil +} + +// DumpRange dumps the database spanning the specified time range into a slice +// of [worklog.BucketMetadata]. +func (db *DB) DumpRange(ctx context.Context, start, end time.Time) ([]worklog.BucketMetadata, error) { + m, err := db.buckets(ctx) + if err != nil { + return nil, err + } + for i, b := range m { + bucket, ok := strings.CutSuffix(b.ID, "_"+b.Hostname) + if !ok { + return m, fmt.Errorf("invalid bucket ID at %d: %s", i, b.ID) + } + e, err := db.dumpEventsRange(ctx, b.ID, start, end, nil) + if err != nil { + return m, err + } + for j := range e { + e[j].Bucket = bucket + } + m[i].Events = e + } + return m, nil +} + +const ( + dumpEventsRange = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and endtime >= $2 and starttime <= $3 limit $4` + + dumpEventsRangeUntil = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and starttime <= $2 limit $3` + + dumpEventsRangeFrom = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and endtime >= $2 limit $3` + + dumpEventsLimit = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) limit $2` +) + +func (db *DB) dumpEventsRange(ctx context.Context, bid string, start, end time.Time, limit *int) ([]worklog.Event, error) { + var e []worklog.Event + err := 
db.eventsRangeFunc(ctx, bid, start, end, limit, func(m worklog.Event) error { + e = append(e, m) + return nil + }, false) + return e, err +} + +// Load loads a complete database from a slice of [worklog.BucketMetadata]. +// Event IDs will be regenerated by the backing database and so will not +// match the input data. If replace is true and a bucket already exists matching +// the bucket in the provided buckets slice, the existing events will be +// deleted and replaced. If replace is false, the new events will be added to +// the existing events in the store. +func (db *DB) Load(ctx context.Context, buckets []worklog.BucketMetadata, replace bool) (err error) { + tx, err := db.store.Begin(ctx) + if err != nil { + return err + } + defer txDone(ctx, tx, &err) + for _, m := range buckets { + var b *worklog.BucketMetadata + b, err = createBucket(ctx, tx, m.ID, m.Name, m.Type, m.Client, m.Hostname, m.Created, m.Data) + if !sameBucket(&m, b) { + return fmt.Errorf("mismatched bucket: %s != %s", bucketString(&m), bucketString(b)) + } + if replace { + _, err = tx.Exec(ctx, DeleteBucketEvents, m.ID) + if err != nil { + return err + } + } + for i, e := range m.Events { + bid := fmt.Sprintf("%s_%s", e.Bucket, m.Hostname) + _, err = insertEvent(ctx, tx, bid, &m.Events[i]) + if err != nil { + return err + } + } + } + return nil +} + +func sameBucket(a, b *worklog.BucketMetadata) bool { + return a.ID == b.ID && + a.Name == b.Name && + a.Type == b.Type && + a.Client == b.Client && + a.Hostname == b.Hostname +} + +func bucketString(b *worklog.BucketMetadata) string { + return fmt.Sprintf("{id=%s name=%s type=%s client=%s hostname=%s}", b.ID, + b.Name, + b.Type, + b.Client, + b.Hostname) +} + +const Buckets = `select id, name, type, client, hostname, created, timezone, datastr from buckets` + +// Buckets returns the full set of bucket metadata. +// The SQL command run is [Buckets]. 
+func (db *DB) Buckets(ctx context.Context) ([]worklog.BucketMetadata, error) { + return db.buckets(ctx) +} + +func (db *DB) buckets(ctx context.Context) ([]worklog.BucketMetadata, error) { + rows, err := db.store.Query(ctx, Buckets) + if err != nil { + return nil, err + } + defer rows.Close() + var b []worklog.BucketMetadata + for rows.Next() { + var ( + m worklog.BucketMetadata + tz string + ) + err = rows.Scan(&m.ID, &m.Name, &m.Type, &m.Client, &m.Hostname, &m.Created, &tz, &m.Data) + if err != nil { + return nil, err + } + timezone, err := time.Parse(tzFormat, tz) + if err != nil { + return nil, fmt.Errorf("invalid timezone for %s bucket %s: %w", m.Name, m.ID, err) + } + m.Created = m.Created.In(timezone.Location()) + b = append(b, m) + } + return b, nil +} + +const Event = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and id = $2 limit 1` + +const Events = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +)` + +// Buckets returns the full set of events in the bucket with the provided +// internal bucket ID. +// The SQL command run is [Events]. 
+func (db *DB) Events(ctx context.Context, bid string) ([]worklog.Event, error) { + return db.events(ctx, bid) +} + +func (db *DB) events(ctx context.Context, bid string) ([]worklog.Event, error) { + rows, err := db.store.Query(ctx, Events, bid) + if err != nil { + return nil, err + } + defer rows.Close() + var e []worklog.Event + for rows.Next() { + var ( + m worklog.Event + tz string + ) + err = rows.Scan(&m.ID, &m.Start, &m.End, &tz, &m.Data) + if err != nil { + return nil, err + } + timezone, err := time.Parse(tzFormat, tz) + if err != nil { + return nil, fmt.Errorf("invalid timezone for event %d: %w", m.ID, err) + } + loc := timezone.Location() + m.Start = m.Start.In(loc) + m.End = m.End.In(loc) + e = append(e, m) + } + return e, nil +} + +const ( + EventsRange = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and endtime >= $2 and starttime <= $3 order by endtime desc limit $4` + + EventsRangeUntil = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and starttime <= $2 order by endtime desc limit $3` + + EventsRangeFrom = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and endtime >= $2 order by endtime desc limit $3` + + EventsLimit = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) order by endtime desc limit $2` +) + +// EventsRange returns the events in the bucket with the provided bucket ID +// within the specified time range, sorted descending by end time. +// The SQL command run is [EventsRange], [EventsRangeUntil], [EventsRangeFrom] +// or [EventsLimit] depending on whether start and end are zero. 
+func (db *DB) EventsRange(ctx context.Context, bid string, start, end time.Time, limit int) ([]worklog.Event, error) { + var lim *int + if limit >= 0 { + lim = &limit + } + var e []worklog.Event + err := db.eventsRangeFunc(ctx, bid, start, end, lim, func(m worklog.Event) error { + e = append(e, m) + return nil + }, true) + return e, err +} + +// EventsRange calls fn on all the events in the bucket with the provided +// bucket ID within the specified time range, sorted descending by end time. +// The SQL command run is [EventsRange], [EventsRangeUntil], [EventsRangeFrom] +// or [EventsLimit] depending on whether start and end are zero. +func (db *DB) EventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error { + var lim *int + if limit >= 0 { + lim = &limit + } + return db.eventsRangeFunc(ctx, bid, start, end, lim, fn, true) +} + +func (db *DB) eventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit *int, fn func(worklog.Event) error, order bool) error { + var ( + query string + rows pgx.Rows + err error + ) + switch { + case !start.IsZero() && !end.IsZero(): + query = EventsRange + if !order { + query = dumpEventsRange + } + rows, err = db.store.Query(ctx, query, bid, start, end, limit) + case !start.IsZero(): + query = EventsRangeFrom + if !order { + query = dumpEventsRangeFrom + } + rows, err = db.store.Query(ctx, query, bid, start, limit) + case !end.IsZero(): + query = EventsRangeUntil + if !order { + query = dumpEventsRangeUntil + } + rows, err = db.store.Query(ctx, query, bid, end, limit) + default: + query = EventsLimit + if !order { + query = dumpEventsLimit + } + rows, err = db.store.Query(ctx, query, bid, limit) + } + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var ( + m worklog.Event + tz string + ) + err = rows.Scan(&m.ID, &m.Start, &m.End, &tz, &m.Data) + if err != nil { + return err + } + timezone, err := time.Parse(tzFormat, tz) + if err 
!= nil { + return fmt.Errorf("invalid timezone for event %d: %w", m.ID, err) + } + loc := timezone.Location() + m.Start = m.Start.In(loc) + m.End = m.End.In(loc) + err = fn(m) + if err != nil { + return err + } + } + return nil +} + +// Select allows running an SQLite SELECT query. The query is run on a read-only +// connection to the database. +func (db *DB) Select(query string) ([]map[string]any, error) { + if db.roStore == nil { + return nil, errors.New("no read-only connection") + } + + rows, err := db.roStore.Query(context.Background(), query) + if err != nil { + return nil, err + } + defer rows.Close() + + descs := rows.FieldDescriptions() + var e []map[string]any + for rows.Next() { + cols, err := rows.Values() + if err != nil { + return nil, err + } + result := make(map[string]any) + for i, v := range cols { + result[descs[i].Name] = v + } + e = append(e, result) + } + return e, rows.Err() +} + +const AmendEventsPrepare = `update events set datastr = jsonb_set(datastr, '{amend}', '[]') + where + starttime < $3 and + endtime > $2 and + not datastr::jsonb ? 
'amend' and + bucketrow = ( + select rowid from buckets where id = $1 + );` +const AmendEventsUpdate = `update events set datastr = jsonb_set( + datastr, + '{amend}', + datastr->'amend' || jsonb_build_object( + 'time', $2::text, + 'msg', $3::text, + 'replace', ( + with replace as ( + select jsonb($6::text) replacements + ) + select + jsonb_agg(new order by idx) trimmed_replacements + from + replace, lateral ( + select idx, jsonb_object_agg(key, + case + when key = 'start' + then to_jsonb(greatest(old::text::timestamptz, starttime)) + when key = 'end' + then to_jsonb(least(old::text::timestamptz, endtime)) + else old + end + ) + from + jsonb_array_elements(replacements) + with ordinality rs(r, idx), + jsonb_each(r) each(key, old) + where + (r->>'start')::timestamptz < endtime and + (r->>'end')::timestamptz > starttime + group BY idx + ) news(idx, new) + ) + ) + ) + where + starttime < $5 and + endtime > $4 and + bucketrow = ( + select rowid from buckets where id = $1 + );` + +// AmendEvents adds amendment notes to the data for events in the store +// overlapping the note. On return the note.Replace slice will be sorted. +// +// The SQL commands run are [AmendEventsPrepare] and [AmendEventsUpdate] +// in a transaction. 
+func (db *DB) AmendEvents(ctx context.Context, ts time.Time, note *worklog.Amendment) (sql.Result, error) { + if len(note.Replace) == 0 { + return driver.RowsAffected(0), nil + } + sort.Slice(note.Replace, func(i, j int) bool { + return note.Replace[i].Start.Before(note.Replace[j].Start) + }) + start := note.Replace[0].Start + end := note.Replace[0].End + for i, r := range note.Replace[1:] { + if note.Replace[i].End.After(r.Start) { + return nil, fmt.Errorf("overlapping replacements: [%d].end (%s) is after [%d].start (%s)", + i, note.Replace[i].End.Format(time.RFC3339), i+1, r.Start.Format(time.RFC3339)) + } + if r.End.After(end) { + end = r.End + } + } + replace, err := json.Marshal(note.Replace) + if err != nil { + return nil, err + } + var res pgconn.CommandTag + err = pgx.BeginFunc(ctx, db.store, func(tx pgx.Tx) error { + _, err = tx.Exec(ctx, AmendEventsPrepare, db.BucketID(note.Bucket), start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano)) + if err != nil { + return fmt.Errorf("prepare amendment list: %w", err) + } + res, err = tx.Exec(ctx, AmendEventsUpdate, db.BucketID(note.Bucket), ts.Format(time.RFC3339Nano), note.Message, start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano), replace) + if err != nil { + return fmt.Errorf("add amendments: %w", err) + } + return nil + }) + if err != nil { + return nil, err + } + return result{res}, err +} + +const DeleteEvent = `delete from events where bucketrow = ( + select rowid from buckets where id = $1 +) and id = $2` + +const DeleteBucketEvents = `delete from events where bucketrow in ( + select rowid from buckets where id = $1 +)` + +const DeleteBucket = `delete from buckets where id = $1` diff --git a/cmd/worklog/pgstore/db_test.go b/cmd/worklog/pgstore/db_test.go new file mode 100644 index 0000000..1b2bda6 --- /dev/null +++ b/cmd/worklog/pgstore/db_test.go @@ -0,0 +1,825 @@ +// Copyright ©2024 Dan Kortschak. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pgstore + +import ( + "bufio" + "compress/gzip" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgx/v5" + "golang.org/x/exp/slices" + + worklog "github.com/kortschak/dex/cmd/worklog/api" +) + +var ( + verbose = flag.Bool("verbose_log", false, "print full logging") + lines = flag.Bool("show_lines", false, "log source code position") + keep = flag.Bool("keep", false, "keep test database after tests") +) + +const testDir = "testdata" + +func TestDB(t *testing.T) { + if !*keep { + t.Cleanup(func() { + os.RemoveAll(testDir) + }) + } + const dbBaseName = "test_worklog_database" + + pgHost := os.Getenv("PGHOST") + if pgHost == "" { + t.Fatal("must have postgres host in $PGHOST") + } + pgPort := os.Getenv("PGPORT") + if pgPort == "" { + t.Fatal("must have postgres port number in $PGPORT") + } + pgHostPort := pgHost + ":" + pgPort + + pgUser := os.Getenv("PGUSER") + if pgUser == "" { + t.Fatal("must have postgres user in $PGUSER") + } + var pgUserinfo *url.Userinfo + pgPassword := os.Getenv("PGPASSWORD") + if pgPassword != "" { + pgUserinfo = url.UserPassword(pgUser, pgPassword) + } else { + home, err := os.UserHomeDir() + if err != nil { + t.Fatalf("could not get home directory: %v", err) + } + pgpass, err := os.Open(filepath.Join(home, ".pgpass")) + if err != nil { + t.Fatalf("could not open .pgpass: %v", err) + } + defer pgpass.Close() + fi, err := pgpass.Stat() + if err != nil { + t.Fatalf("could not stat .pgpass: %v", err) + } + if fi.Mode()&0o077 != 0o000 { + t.Fatalf(".pgpass permissions too relaxed: %s", fi.Mode()) + } + sc := bufio.NewScanner(pgpass) + found := false + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + e, err := 
parsePgPassLine(line) + if err != nil { + t.Fatalf("could not parse .pgpass: %v", err) + } + if e.match(pgUser, pgHost, pgPort, "*") { + found = true + break + } + } + if sc.Err() != nil { + t.Fatalf("unexpected error reading .pgpass: %v", err) + } + if !found { + t.Fatal("must have postgres password in $PGPASSWORD or .pgpass") + } + + pgUserinfo = url.User(pgUser) + } + + ctx := context.Background() + for _, interval := range []struct { + name string + duration time.Duration + }{ + {name: "second", duration: time.Second}, + {name: "subsecond", duration: time.Second / 10}, + } { + dbName := dbBaseName + "_" + interval.name + dropTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + drop := createTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + if !*keep { + defer drop() + } + + t.Run(interval.name, func(t *testing.T) { + t.Run("db", func(t *testing.T) { + u := url.URL{ + Scheme: "postgres", + User: url.UserPassword(pgUser, pgPassword), + Host: pgHostPort, + Path: dbName, + } + dbURL := u.String() + + db, err := Open(ctx, dbURL, "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close db: %v", err) + } + + now := time.Now().Round(time.Millisecond) + bucket := "test_bucket" + data := []worklog.BucketMetadata{ + { + ID: "test_bucket_test_host", + Name: "test_bucket_name", + Type: "test_bucket_type", + Client: "testing", + Hostname: "test_host", + Created: now, + Data: map[string]any{"key0": "value0"}, + Events: []worklog.Event{ + { + Bucket: bucket, + Start: now.Add(1 * interval.duration), + End: now.Add(2 * interval.duration), + Data: map[string]any{"key1": "value1"}, + }, + { + Bucket: bucket, + Start: now.Add(3 * interval.duration), + End: now.Add(4 * interval.duration), + Data: map[string]any{"key2": "value2"}, + }, + { + Bucket: bucket, + Start: now.Add(5 * interval.duration), + End: now.Add(6 * interval.duration), + Data: map[string]any{"key3": 
"value3"}, + }, + }, + }, + } + + // worklog.Event.ID is the only int64 field, so + // this is easier than filtering on field name. + ignoreID := cmp.FilterValues( + func(_, _ int64) bool { return true }, + cmp.Ignore(), + ) + + db, err = Open(ctx, dbURL, "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + defer func() { + err = db.Close(ctx) + if err != nil { + t.Errorf("failed to close db: %v", err) + } + }() + + t.Run("load_dump", func(t *testing.T) { + err = db.Load(ctx, data, false) + if err != nil { + t.Errorf("failed to load data no replace: %v", err) + } + err = db.Load(ctx, data, true) + if err != nil { + t.Errorf("failed to load data replace: %v", err) + } + + got, err := db.Dump(ctx) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + want := data + if !cmp.Equal(want, got, ignoreID) { + t.Errorf("unexpected dump result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID)) + } + + gotRange, err := db.DumpRange(ctx, now.Add(3*interval.duration), now.Add(4*interval.duration)) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRange := slices.Clone(data) + wantRange[0].Events = wantRange[0].Events[1:2] + if !cmp.Equal(wantRange, gotRange, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRange, gotRange, ignoreID)) + } + + gotRangeFrom, err := db.DumpRange(ctx, now.Add(3*interval.duration), time.Time{}) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRangeFrom := slices.Clone(data) + wantRangeFrom[0].Events = wantRangeFrom[0].Events[1:] + if !cmp.Equal(wantRangeFrom, gotRangeFrom, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeFrom, gotRangeFrom, ignoreID)) + } + + gotRangeUntil, err := db.DumpRange(ctx, time.Time{}, now.Add(4*interval.duration)) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRangeUntil := 
slices.Clone(data) + wantRangeUntil[0].Events = wantRangeUntil[0].Events[:2] + if !cmp.Equal(wantRangeUntil, gotRangeUntil, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeUntil, gotRangeUntil, ignoreID)) + } + + gotRangeAll, err := db.DumpRange(ctx, time.Time{}, time.Time{}) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRangeAll := slices.Clone(data) + if !cmp.Equal(wantRangeAll, gotRangeAll, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeAll, gotRangeAll, ignoreID)) + } + + }) + + t.Run("backup", func(t *testing.T) { + workDir := filepath.Join(testDir, interval.name) + err := os.MkdirAll(workDir, 0o755) + if err != nil && !errors.Is(err, fs.ErrExist) { + t.Fatalf("failed to make dir: %v", err) + } + + path, err := db.Backup(ctx, workDir) + if err != nil { + t.Errorf("failed to backup database: %v", err) + } + f, err := os.Open(path) + if errors.Is(err, fs.ErrNotExist) { + t.Error("did not create backup") + } else if err != nil { + t.Errorf("unexpected error opening backup file: %v", err) + } + r, err := gzip.NewReader(f) + if err != nil { + t.Errorf("unexpected error opening gzip backup file: %v", err) + } + _, err = io.Copy(io.Discard, r) + if err != nil { + t.Errorf("unexpected error gunzipping backup file: %v", err) + } + }) + + t.Run("last_event", func(t *testing.T) { + got, err := db.LastEvent(ctx, bucket) + if err != nil { + t.Fatalf("failed to get last event: %v", err) + } + got.Bucket = bucket + + want := &data[0].Events[len(data[0].Events)-1] + if !cmp.Equal(want, got, ignoreID) { + t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID)) + } + }) + + t.Run("update_last_event", func(t *testing.T) { + e := data[0].Events[len(data[0].Events)-1] + e.End = e.End.Add(interval.duration) + last, err := db.LastEvent(ctx, e.Bucket) + if err != nil { + t.Fatalf("failed to get last event: %v", err) + } + 
e.ID = last.ID + _, err = db.UpdateEvent(ctx, &e) + if err != nil { + t.Fatalf("failed to update event: %v", err) + } + if err != nil { + t.Errorf("failed to update event: %v", err) + } + got, err := db.LastEvent(ctx, bucket) + if err != nil { + t.Errorf("failed to get last event: %v", err) + } + got.Bucket = bucket + + want := &e + if !cmp.Equal(want, got, ignoreID) { + t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID)) + } + }) + + t.Run("events_range", func(t *testing.T) { + bid := db.BucketID(bucket) + for _, loc := range []*time.Location{time.Local, time.UTC} { + t.Run(loc.String(), func(t *testing.T) { + got, err := db.EventsRange(ctx, bid, now.Add(3*interval.duration).In(loc), now.Add(4*interval.duration).In(loc), -1) + if err != nil { + t.Errorf("failed to load data: %v", err) + } + for i := range got { + got[i].Bucket = bucket + } + + want := data[0].Events[1:2] + if !cmp.Equal(want, got, ignoreID) { + t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID)) + } + }) + } + }) + + t.Run("update_last_event_coequal", func(t *testing.T) { + dbName := dbName + "_coequal" + dropTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + drop := createTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + if !*keep { + defer drop() + } + + u := url.URL{ + Scheme: "postgres", + User: url.UserPassword(pgUser, pgPassword), + Host: pgHostPort, + Path: dbName, + } + db, err := Open(ctx, u.String(), "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close(ctx) + + buckets := []string{ + `{"id":"window","name":"window-watcher","type":"currentwindow","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38.305691865+09:30"}`, + `{"id":"afk","name":"afk-watcher","type":"afkstatus","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38.310302464+09:30"}`, + } + for _, msg := range buckets { + var b worklog.BucketMetadata + 
err := json.Unmarshal([]byte(msg), &b) + if err != nil { + t.Fatalf("failed to unmarshal bucket message: %v", err) + } + _, err = db.CreateBucket(ctx, b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) + if err != nil { + t.Fatalf("failed to create bucket: %v", err) + } + } + + events := []string{ + `{"bucket":"window","start":"2023-06-12T19:54:39.248859996+09:30","end":"2023-06-12T19:54:39.248859996+09:30","data":{"app":"Gnome-terminal","title":"Terminal"},"continue":false}`, + `{"bucket":"afk","start":"2023-06-12T19:54:39.248859996+09:30","end":"2023-06-12T19:54:39.248859996+09:30","data":{"afk":false,"locked":false},"continue":false}`, + `{"bucket":"window","start":"2023-06-12T19:54:40.247357339+09:30","end":"2023-06-12T19:54:40.247357339+09:30","data":{"app":"Gnome-terminal","title":"Terminal"},"continue":false}`, + `{"bucket":"afk","start":"2023-06-12T19:54:39.248859996+09:30","end":"2023-06-12T19:54:40.247357339+09:30","data":{"afk":false,"locked":false},"continue":true}`, + } + for i, msg := range events { + var note *worklog.Event + err := json.Unmarshal([]byte(msg), &note) + if err != nil { + t.Fatalf("failed to unmarshal event message: %v", err) + } + if note.Continue != nil && *note.Continue { + last, err := db.LastEvent(ctx, note.Bucket) + if err != nil { + t.Fatalf("failed to get last event: %v", err) + } + note.ID = last.ID + _, err = db.UpdateEvent(ctx, note) + if err != nil { + t.Fatalf("failed to update event: %v", err) + } + } else { + _, err = db.InsertEvent(ctx, note) + if err != nil { + t.Fatalf("failed to insert event: %v", err) + } + } + + dump, err := db.Dump(ctx) + if err != nil { + t.Fatalf("failed to dump db after step %d: %v", i, err) + } + t.Logf("note: %#v\ndump: %#v", note, dump) + + for _, b := range dump { + for _, e := range b.Events { + if e.Bucket == "window" { + if _, ok := e.Data["afk"]; ok { + t.Errorf("unexpectedly found afk data in window bucket: %v", e) + } + } + } + } + } + }) + + t.Run("amend", func(t *testing.T) { + // 
t.Skip("not yet working") + + dbName := dbName + "_amend" + dropTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + drop := createTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + if !*keep { + defer drop() + } + + u := url.URL{ + Scheme: "postgres", + User: pgUserinfo, + Host: pgHostPort, + Path: dbName, + } + db, err := Open(ctx, u.String(), "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close(ctx) + + buckets := []string{ + `{"id":"window","name":"window-watcher","type":"currentwindow","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38Z"}`, + `{"id":"afk","name":"afk-watcher","type":"afkstatus","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38Z"}`, + } + for _, msg := range buckets { + var b worklog.BucketMetadata + err := json.Unmarshal([]byte(msg), &b) + if err != nil { + t.Fatalf("failed to unmarshal bucket message: %v", err) + } + _, err = db.CreateBucket(ctx, b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) + if err != nil { + t.Fatalf("failed to create bucket: %v", err) + } + } + + events := []string{ + `{"bucket":"window","start":"2023-06-12T19:54:40Z","end":"2023-06-12T19:54:45Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:40Z","end":"2023-06-12T19:54:45Z","data":{"afk":false,"locked":false}}`, + `{"bucket":"window","start":"2023-06-12T19:54:45Z","end":"2023-06-12T19:54:50Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:45Z","end":"2023-06-12T19:54:50Z","data":{"afk":true,"locked":true}}`, + `{"bucket":"window","start":"2023-06-12T19:54:50Z","end":"2023-06-12T19:54:55Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:50Z","end":"2023-06-12T19:54:55Z","data":{"afk":false,"locked":false}}`, + 
`{"bucket":"window","start":"2023-06-12T19:54:55Z","end":"2023-06-12T19:54:59Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:55Z","end":"2023-06-12T19:54:59Z","data":{"afk":true,"locked":true}}`, + } + for _, msg := range events { + var note *worklog.Event + err := json.Unmarshal([]byte(msg), &note) + if err != nil { + t.Fatalf("failed to unmarshal event message: %v", err) + } + _, err = db.InsertEvent(ctx, note) + if err != nil { + t.Fatalf("failed to insert event: %v", err) + } + } + msg := `{"bucket":"afk","msg":"testing","replace":[{"start":"2023-06-12T19:54:39Z","end":"2023-06-12T19:54:51Z","data":{"afk":true,"locked":true}}]}` + var amendment *worklog.Amendment + err = json.Unmarshal([]byte(msg), &amendment) + if err != nil { + t.Fatalf("failed to unmarshal event message: %v", err) + } + _, err = db.AmendEvents(ctx, time.Time{}, amendment) + if err != nil { + t.Errorf("unexpected error amending events: %v", err) + } + dump, err := db.Dump(ctx) + if err != nil { + t.Fatalf("failed to dump db: %v", err) + } + for _, bucket := range dump { + for i, event := range bucket.Events { + switch event.Bucket { + case "window": + _, ok := event.Data["amend"] + if ok { + t.Errorf("unexpected amendment in window event %d: %v", i, event.Data) + } + case "afk": + a, ok := event.Data["amend"] + if !ok { + for _, r := range amendment.Replace { + if overlaps(event.Start, event.End, r.Start, r.End) { + t.Errorf("expected amendment for event %d of afk", i) + break + } + } + break + } + var n []worklog.Amendment + err = remarshalJSON(&n, a) + if err != nil { + t.Errorf("unexpected error remarshalling []AmendEvents: %v", err) + } + if len(n) == 0 { + t.Fatal("unexpected zero-length []AmendEvents") + } + for _, r := range n[len(n)-1].Replace { + if r.Start.Before(event.Start) { + t.Errorf("replacement start extends before start of event: %s < %s", + r.Start.Format(time.RFC3339), event.Start.Format(time.RFC3339)) + } + if noted, ok 
:= findOverlap(r, amendment.Replace); ok && !r.Start.Equal(event.Start) && !r.Start.Equal(noted.Start) { + t.Errorf("non-truncated replacement start was altered: %s != %s", + r.Start.Format(time.RFC3339), noted.Start.Format(time.RFC3339)) + } + if r.End.After(event.End) { + t.Errorf("replacement end extends beyond end of event: %s > %s", + r.End.Format(time.RFC3339), event.End.Format(time.RFC3339)) + } + if noted, ok := findOverlap(r, amendment.Replace); ok && !r.End.Equal(event.End) && !r.End.Equal(noted.End) { + t.Errorf("non-truncated replacement end was altered: %s != %s", + r.End.Format(time.RFC3339), noted.End.Format(time.RFC3339)) + } + } + default: + t.Errorf("unexpected event bucket name in event %d of %s: %s", i, bucket.ID, event.Bucket) + } + } + } + }) + + t.Run("dynamic_query", func(t *testing.T) { + grantReadAccess(t, ctx, pgUserinfo, pgHost, dbName, pgUser+"_ro") + u := url.URL{ + Scheme: "postgres", + User: pgUserinfo, + Host: pgHostPort, + Path: dbName, + } + db, err := Open(ctx, u.String(), "test_host") + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close(ctx) + + dynamicTests := []struct { + name string + sql string + wantErr error + }{ + { + name: "kitchen_or", + sql: `select datastr ->> 'title', starttime, datastr ->> 'afk' from events + where + not (datastr ->> 'afk')::boolean or datastr ->> 'title' = 'Terminal' + limit 2`, + }, + { + name: "kitchen_and", + sql: `select datastr ->> 'title', starttime, datastr ->> 'afk' from events + where + not (datastr ->> 'afk')::boolean and datastr ->> 'title' = 'Terminal' + limit 2`, + }, + { + name: "count", + sql: `select count(*) from events`, + }, + { + name: "all", + sql: `select * from events`, + }, + { + name: "non_null_afk", + sql: `select * from events where datastr ? 
'app'`, + }, + { + name: "drop_table", + sql: `drop table events`, + wantErr: errors.New("ERROR: must be owner of relation events (SQLSTATE 42501)"), + }, + { + name: "sneaky_create_table", + sql: "select count(*) from events; create table if not exists t(i)", + wantErr: errors.New("ERROR: syntax error at end of input (SQLSTATE 42601)"), + }, + { + name: "sneaky_drop_table", + sql: "select count(*) from events; drop table events", + wantErr: errors.New("ERROR: cannot insert multiple commands into a prepared statement (SQLSTATE 42601)"), + }, + } + + for _, test := range dynamicTests { + t.Run(test.name, func(t *testing.T) { + got, err := db.Select(test.sql) + if !sameError(err, test.wantErr) { + t.Errorf("unexpected error: got:%v want:%v", err, test.wantErr) + return + } + if err != nil { + return + } + + rows, err := db.store.Query(ctx, test.sql) + if err != nil { + t.Fatalf("unexpected error for query: %v", err) + } + descs := rows.FieldDescriptions() + var want []map[string]any + for rows.Next() { + args := make([]any, len(descs)) + for i := range args { + var a any + args[i] = &a + } + err = rows.Scan(args...) 
+ if err != nil { + t.Fatal(err) + } + row := make(map[string]any) + for i, a := range args { + row[descs[i].Name] = *(a.(*any)) + } + want = append(want, row) + } + rows.Close() + + if !cmp.Equal(want, got) { + t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got)) + } + }) + } + }) + }) + }) + } +} + +type pgPassEntry struct { + host string + port string + database string + user string + password string +} + +func (e pgPassEntry) match(user, host, port, database string) bool { + return user == e.user && + (host == e.host || e.host == "*") && + (port == e.port || e.port == "*") && + (database == e.database || e.database == "*") +} + +func parsePgPassLine(text string) (pgPassEntry, error) { + var ( + entry pgPassEntry + field int + last int + escape bool + ) + for i, r := range text { + switch r { + case '\\': + escape = !escape + continue + case ':': + if escape { + break + } + switch field { + case 0: + entry.host = text[last:i] + case 1: + entry.port = text[last:i] + case 2: + entry.database = text[last:i] + case 3: + entry.user = text[last:i] + default: + return entry, errors.New("too many fields") + } + last = i + 1 + field++ + } + escape = false + } + entry.password = text[last:] + return entry, nil +} + +func createTestDB(t *testing.T, ctx context.Context, user *url.Userinfo, host, dbname string) func() { + t.Helper() + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: host, + Path: "template1", + } + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + t.Fatalf("failed to open admin database: %v", err) + } + _, err = db.Exec(ctx, "create database "+dbname) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close admin connection: %v", err) + } + + return func() { + dropTestDB(t, ctx, user, host, dbname) + } +} + +func dropTestDB(t *testing.T, ctx context.Context, user *url.Userinfo, host, dbname string) { + t.Helper() + + u := 
url.URL{ + Scheme: "postgres", + User: user, + Host: host, + Path: "template1", + } + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + t.Fatalf("failed to open admin database: %v", err) + } + _, err = db.Exec(ctx, "drop database if exists "+dbname) + if err != nil { + t.Fatalf("failed to drop test database: %v", err) + } + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close admin connection: %v", err) + } +} + +func grantReadAccess(t *testing.T, ctx context.Context, user *url.Userinfo, host, dbname, target string) { + t.Helper() + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: host, + Path: dbname, + } + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + statements := []string{ + fmt.Sprintf("GRANT CONNECT ON DATABASE %s TO %s", dbname, target), + fmt.Sprintf("GRANT USAGE ON SCHEMA public TO %s", target), + fmt.Sprintf("GRANT SELECT ON ALL TABLES IN SCHEMA public TO %s", target), + } + for _, s := range statements { + _, err = db.Exec(ctx, s) + if err != nil { + t.Fatalf("failed to execute grant: %v", err) + } + } + + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close connection: %v", err) + } +} + +func ptr[T any](v T) *T { return &v } + +func findOverlap(n worklog.Replacement, h []worklog.Replacement) (worklog.Replacement, bool) { + for _, c := range h { + if overlaps(n.Start, n.End, c.Start, c.End) { + return c, true + } + } + return worklog.Replacement{}, false +} + +func overlaps(as, ae, bs, be time.Time) bool { + return ae.After(bs) && as.Before(be) +} +func remarshalJSON(dst, src any) error { + b, err := json.Marshal(src) + if err != nil { + return err + } + return json.Unmarshal(b, dst) +} + +func sameError(a, b error) bool { + switch { + case a != nil && b != nil: + return a.Error() == b.Error() + default: + return a == b + } +} diff --git a/go.mod b/go.mod index 79c639b..6cd7825 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( 
github.com/gofrs/flock v0.8.1 github.com/google/cel-go v0.20.1 github.com/google/go-cmp v0.6.0 + github.com/jackc/pgx/v5 v5.6.0 github.com/kortschak/ardilla v0.0.0-20240121074954-8297d203ffa4 github.com/kortschak/goroutine v1.1.2 github.com/kortschak/jsonrpc2 v0.0.0-20240214190357-0539ebd6a045 @@ -41,6 +42,8 @@ require ( github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect @@ -49,6 +52,7 @@ require ( github.com/tdewolff/font v0.0.0-20240417221047-e5855237f87b // indirect github.com/tdewolff/minify/v2 v2.20.5 // indirect github.com/tdewolff/parse/v2 v2.7.3 // indirect + golang.org/x/crypto v0.23.0 // indirect golang.org/x/exp/shiny v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/text v0.16.0 // indirect google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37 // indirect diff --git a/go.sum b/go.sum index 50f147a..82d9dca 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,14 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile 
v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kortschak/ardilla v0.0.0-20240121074954-8297d203ffa4 h1:LBA2QNMoqB/6H1QG9oSzcVQxqLQJ8vX/Vr4SiyGlLNI= github.com/kortschak/ardilla v0.0.0-20240121074954-8297d203ffa4/go.mod h1:CzpsZoRc37ZYYkb/iJHFRW9m+pWuyiS3WAqvuBVFiCo= github.com/kortschak/goroutine v1.1.2 h1:lhllcCuERxMIK5cYr8yohZZScL1na+JM5JYPRclWjck= @@ -111,7 +119,9 @@ github.com/sstallion/go-hid v0.14.1/go.mod h1:fPKp4rqx0xuoTV94gwKojsPG++KNKhxuU8 github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tdewolff/canvas v0.0.0-20240512164826-1cb71758b3b2 h1:n5JfO/T/+VqXuHym08mEAzmDz2QDwMA71+/zl1SG+dE= @@ -125,6 +135,8 @@ github.com/tdewolff/parse/v2 v2.7.3/go.mod h1:9p2qMIHpjRSTr1qnFxQr+igogyTUTlwvf9 github.com/tdewolff/test v1.0.10/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE= github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739 h1:IkjBCtQOOjIn03u/dMQK9g+Iw9ewps4mCl1nB8Sscbo= github.com/tdewolff/test 
v1.0.11-0.20240106005702-7de5f7df4739/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/exp/shiny v0.0.0-20231006140011-7918f672742d h1:grE48C8cjIY0aiHVmFyYgYxxSARQWBABLXKZfQPrBhY= @@ -164,9 +176,10 @@ google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37/go.mod h1:RGgjbofJ google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.21.2 h1:dycHFB/jDc3IyacKipCNSDrjIC0Lm1hyoWOZTRR20Lk=