-
Notifications
You must be signed in to change notification settings - Fork 0
/
dialect_postgres.go
75 lines (59 loc) · 1.69 KB
/
dialect_postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package crud
import (
"database/sql"
"fmt"
"strings"
)
// PostgresDialect groups the Scan, Insert and Update operations used when
// talking to a PostgreSQL database. It carries no state, so the zero value is
// ready to use.
type PostgresDialect struct{}
// Scan reads the current row of rows into args. It needs no
// PostgreSQL-specific handling and simply delegates to genericScan,
// returning whatever error that helper reports.
func (PostgresDialect) Scan(rows *sql.Rows, args ...FieldBinder) error {
return genericScan(rows, args...)
}
// Insert adds obj to table as a new row, using PostgreSQL's numbered
// placeholders ($1, $2, ...).
//
// obj is first passed through deflate; its fields and values are then taken
// from obj.EnumerateFields. A field whose name equals sqlIdFieldName is left
// out of the statement so the database can assign it.
//
// When sqlIdFieldName is non-empty the statement ends in
// "RETURNING <sqlIdFieldName>" and the generated value is returned as id.
// When it is empty the statement is executed without RETURNING and id is 0.
//
// If no fields remain to be written, "DEFAULT VALUES" is used, because
// PostgreSQL does not accept an empty column list.
//
// Errors from deflate, from running the statement, from iterating the result
// and from scanning the returned id are all reported to the caller; id is 0
// whenever the error is non-nil. If the statement succeeds but yields no row,
// sql.ErrNoRows is returned.
//
// Insert panics if EnumerateFields returns slices of different lengths, since
// that indicates a broken FieldEnumerator implementation.
func (PostgresDialect) Insert(db DbIsh, table, sqlIdFieldName string, obj FieldEnumerator) (id int64, er error) {
	if err := deflate(obj); err != nil {
		return 0, err
	}

	objFields, objValues := obj.EnumerateFields()
	if len(objFields) != len(objValues) {
		panic("crud2: FieldEnumerator.EnumerateFields' return values must have same length")
	}

	sqlFields := make([]string, 0, len(objFields))
	sqlValues := make([]interface{}, 0, len(objFields))
	placeholders := make([]string, 0, len(objFields))
	for i, field := range objFields {
		// If there's an id field, skip it so it can be automatically assigned.
		if field == sqlIdFieldName {
			continue
		}
		sqlValues = append(sqlValues, objValues[i])
		sqlFields = append(sqlFields, field)
		// Placeholders are 1-based and must follow the position of the value
		// in sqlValues, not the position in objFields.
		placeholders = append(placeholders, fmt.Sprintf("$%d", len(sqlValues)))
	}

	// "INSERT INTO t () VALUES ()" is a syntax error in PostgreSQL, so fall
	// back to DEFAULT VALUES when there is nothing to write explicitly.
	columnsAndValues := "DEFAULT VALUES"
	if len(sqlFields) > 0 {
		columnsAndValues = fmt.Sprintf("(%s)\nVALUES (%s)",
			strings.Join(sqlFields, ", "), strings.Join(placeholders, ", "))
	}

	if sqlIdFieldName == "" {
		q := fmt.Sprintf("\nINSERT INTO %s\n%s\n", table, columnsAndValues)
		if _, err := db.Exec(q, sqlValues...); err != nil {
			return 0, err
		}
		return 0, nil
	}

	q := fmt.Sprintf("\nINSERT INTO %s\n%s\nRETURNING %s\n", table, columnsAndValues, sqlIdFieldName)
	rows, err := db.Query(q, sqlValues...)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	if !rows.Next() {
		// Next reports false both when iteration failed and when the result
		// set is empty; Err distinguishes the two.
		if err := rows.Err(); err != nil {
			return 0, err
		}
		return 0, sql.ErrNoRows
	}
	if err := rows.Scan(&id); err != nil {
		return 0, err
	}
	return id, nil
}
// Update writes obj back to table. It needs no PostgreSQL-specific handling
// and simply delegates to genericUpdate with the same arguments, returning
// whatever error that helper reports.
func (PostgresDialect) Update(db DbIsh, table, sqlIdFieldName string, obj FieldEnumerator) error {
return genericUpdate(db, table, sqlIdFieldName, obj)
}