diff --git a/db.go b/db.go index 89a52d1..ff00dca 100644 --- a/db.go +++ b/db.go @@ -2,6 +2,7 @@ package pg import ( "context" + "encoding/json" "errors" "fmt" "reflect" @@ -514,3 +515,63 @@ func (db *DB) Unlisten(ctx context.Context, channel string) error { _, err := db.Exec(ctx, query, channel) return err } + +// UpdateJSONB updates a JSONB column (full or partial) in the database by building and executing an +// SQL query based on the provided values and the given tableName and columnName. +// The values parameter is a map of key-value pairs where the key is the json field name and the value is its new value, +// new keys are accepted. Note that tableName and columnName are not escaped. +func (db *DB) UpdateJSONB(ctx context.Context, tableName, columnName, rowID string, values map[string]any, fieldsToUpdate []string, primaryKey *desc.Column) (int64, error) { + var ( + tag pgconn.CommandTag + err error + ) + + // We could extract the id from the column and do a select based on that but let's keep things simple and do it per row id. + // id, ok := values[primaryKey.Name] + // if !ok { + // return 0, fmt.Errorf("missing primary key value") + // } + + // Partial Update. + if len(fieldsToUpdate) > 0 { + // Build query. + query := fmt.Sprintf("UPDATE %s SET %s = ", tableName, columnName) + // Loop over the keys and construct the query using jsonb_set. + for _, key := range fieldsToUpdate { + // Get the value for the key from the map. + value, ok := values[key] + if !ok { + // Handle missing value. + return 0, fmt.Errorf("missing value for key: %s", key) + } + + // Convert the value to JSON. + valueJSON, err := json.Marshal(value) + if err != nil { + // Handle error. + return 0, fmt.Errorf("error converting value to json: %w", err) + } + + // Append the jsonb_set function to the query. + query += fmt.Sprintf("jsonb_set (%s, ' {%s}', '%s'::jsonb, true), ", columnName, key, string(valueJSON)) + } + // Remove the trailing comma and space. + query = strings.TrimSuffix(query, ", ") + + // Add the WHERE clause. + query += " WHERE id = $1" + + // Execute the query with the id parameter. + tag, err = db.Exec(ctx, query, rowID) + } else { + // Full Update. + query := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE id = $2;", tableName, columnName) + tag, err = db.Exec(ctx, query, values, rowID) + } + if err != nil { + // Handle error + return 0, err + } + + return tag.RowsAffected(), nil +}