diff --git a/db_repository.go b/db_repository.go index 8f29bac..b2f9c43 100644 --- a/db_repository.go +++ b/db_repository.go @@ -298,7 +298,7 @@ func (db *DB) UpdateExceptColumns(ctx context.Context, columnsToExcept []string, } columnsToUpdate := td.ListColumnNamesExcept(columnsToExcept...) - return db.updateTableRecords(ctx, td, columnsToUpdate, values) + return db.updateTableRecords(ctx, td, columnsToUpdate, false, values) } // UpdateOnlyColumns updates one or more values in the database by building and executing an @@ -315,17 +315,17 @@ func (db *DB) UpdateOnlyColumns(ctx context.Context, columnsToUpdate []string, v return 0, err // return the error if the table definition is not found } - return db.updateTableRecords(ctx, td, columnsToUpdate, values) + return db.updateTableRecords(ctx, td, columnsToUpdate, false, values) } -func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToUpdate []string, values []any) (int64, error) { +func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToUpdate []string, reportNotFound bool, values []any) (int64, error) { primaryKey, ok := td.PrimaryKey() if !ok { return 0, fmt.Errorf("no primary key found in table definition: %s", td.Name) } if len(values) == 1 { - return db.updateTableRecord(ctx, values[0], columnsToUpdate, primaryKey) + return db.updateTableRecord(ctx, values[0], columnsToUpdate, reportNotFound, primaryKey) } // if more than one: update each value inside a transaction. @@ -333,7 +333,7 @@ func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToU err := db.InTransaction(ctx, func(db *DB) error { for _, value := range values { - rowsAffected, err := db.updateTableRecord(ctx, value, columnsToUpdate, primaryKey) + rowsAffected, err := db.updateTableRecord(ctx, value, columnsToUpdate, reportNotFound, primaryKey) if err != nil { return err } @@ -350,13 +350,22 @@ func (db *DB) updateTableRecords(ctx context.Context, td *desc.Table, columnsToU return totalRowsAffected, nil } -func (db *DB) updateTableRecord(ctx context.Context, value any, columnsToUpdate []string, primaryKey *desc.Column) (int64, error) { +func (db *DB) updateTableRecord(ctx context.Context, value any, columnsToUpdate []string, reportNotFound bool, primaryKey *desc.Column) (int64, error) { // build the SQL query and arguments using the table definition and its primary key. - query, args, err := desc.BuildUpdateQuery(value, columnsToUpdate, primaryKey) + query, args, err := desc.BuildUpdateQuery(value, columnsToUpdate, reportNotFound, primaryKey) if err != nil { return 0, err } + if reportNotFound { + scanErr := db.QueryRow(ctx, query, args...).Scan(nil) + if scanErr != nil { + return 0, scanErr + } + + return 1, nil + } + // execute the query using db.Exec and pass in the primary key values as a parameter tag, err := db.Exec(ctx, query, args...) if err != nil { diff --git a/desc/insert_query.go b/desc/insert_query.go index 61f641f..ef2f6f0 100644 --- a/desc/insert_query.go +++ b/desc/insert_query.go @@ -18,7 +18,7 @@ func BuildInsertQuery(td *Table, structValue reflect.Value, idPtr any, forceOnCo if idPtr != nil { // if idPtr is not nil, it means we want to get the primary key value of the inserted row columnDefinition, ok := td.PrimaryKey() // get the primary key column definition from the table definition - if ok && idPtr != nil { + if ok { returningColumn = columnDefinition.Name // assign the column name to returningColumn } } diff --git a/desc/update_query.go b/desc/update_query.go index a54f9e3..dfb0135 100644 --- a/desc/update_query.go +++ b/desc/update_query.go @@ -8,7 +8,7 @@ import ( // BuildUpdateQuery builds and returns an SQL query for updating a row in the table, // using the given struct value and the primary key. -func BuildUpdateQuery(value any, columnsToUpdate []string, primaryKey *Column) (string, []any, error) { +func BuildUpdateQuery(value any, columnsToUpdate []string, reportNotFound bool, primaryKey *Column) (string, []any, error) { args, err := extractUpdateArguments(value, columnsToUpdate, primaryKey) if err != nil { return "", nil, err @@ -27,7 +27,7 @@ func BuildUpdateQuery(value any, columnsToUpdate []string, primaryKey *Column) ( } // build the SQL query using the table definition and its primary key. - query := buildUpdateQuery(primaryKey.Table, args, primaryKey.Name, shouldUpdateID) + query := buildUpdateQuery(primaryKey.Table, args, primaryKey.Name, shouldUpdateID, reportNotFound) return query, args.Values(), nil } @@ -81,7 +81,7 @@ func extractUpdateArguments(value any, columnsToUpdate []string, primaryKey *Col return args, nil } -func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUpdateID bool) string { +func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUpdateID bool, reportNotFound bool) string { var b strings.Builder b.WriteString(`UPDATE "` + td.Name + `" SET `) @@ -122,6 +122,10 @@ func buildUpdateQuery(td *Table, args Arguments, primaryKeyName string, shouldUp } b.WriteString(` WHERE "` + primaryKeyName + `" = $` + strconv.Itoa(primaryKeyWhereIndex)) + if reportNotFound { + b.WriteString(` RETURNING "` + primaryKeyName + `"`) + } + b.WriteByte(';') return b.String() diff --git a/go.mod b/go.mod index ae4b348..1728473 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,14 @@ go 1.23 require ( github.com/gertd/go-pluralize v0.2.1 github.com/jackc/pgx/v5 v5.7.1 - golang.org/x/mod v0.21.0 + golang.org/x/mod v0.22.0 ) require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - golang.org/x/crypto v0.28.0 // indirect - golang.org/x/sync v0.8.0 // indirect - golang.org/x/text v0.19.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/sync v0.9.0 // indirect + golang.org/x/text v0.20.0 // indirect ) diff --git a/go.sum b/go.sum index 24a5929..4882962 100644 --- a/go.sum +++ b/go.sum @@ -18,14 +18,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= -golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= -golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/repository.go b/repository.go index ef74f68..49ba172 100644 --- a/repository.go +++ b/repository.go @@ -289,7 +289,31 @@ func (repo *Repository[T]) UpdateOnlyColumns(ctx context.Context, columnsToUpdat } valuesAsInterfaces := toInterfaces(values) - return repo.db.updateTableRecords(ctx, repo.td, columnsToUpdate, valuesAsInterfaces) + return repo.db.updateTableRecords(ctx, repo.td, columnsToUpdate, false, valuesAsInterfaces) +} + +// UpdateOnlyColumnsReportNoRows updates one or more values of type T in the database by their primary key values. +// It returns an ErrNoRows if there is no matching row on the given criteria. +func (repo *Repository[T]) UpdateOnlyColumnsReportNoRows(ctx context.Context, columnsToUpdate []string, values ...T) (bool, error) { + if repo.IsReadOnly() { + return false, ErrIsReadOnly + } + + if len(values) == 0 { + return false, nil + } + + valuesAsInterfaces := toInterfaces(values) + _, err := repo.db.updateTableRecords(ctx, repo.td, columnsToUpdate, true, valuesAsInterfaces) + if err != nil { + if errors.Is(err, ErrNoRows) { + return false, nil + } + + return false, err + } + + return true, nil } func toInterfaces[T any](values []T) []any {