Skip to content

Commit dcebc23

Browse files
committed
Merge branch 'master' of github.com:issue9/orm
2 parents b203de5 + b8d1018 commit dcebc23

File tree

10 files changed

+73
-140
lines changed

10 files changed

+73
-140
lines changed

dialect/base.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@ package dialect
66

77
import (
88
"os/exec"
9+
"strings"
910

1011
"github.com/issue9/sliceutil"
1112
)
1213

14+
var (
15+
quoteApostrophe = strings.NewReplacer("'", "''") // 标准 SQL 用法
16+
escapeApostrophe = strings.NewReplacer("'", "\\'") // mysql 用法
17+
)
18+
1319
type base struct {
1420
driverName string
1521
name string

dialect/mysql.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,9 @@ func (m *mysql) SQLType(col *core.Column) (string, error) {
222222
return "", invalidTimeFractional(col)
223223
}
224224
return m.buildType("DATETIME", col, false, 1)
225+
default:
226+
return "", errUncovert(col)
225227
}
226-
227-
return "", errUncovert(col)
228228
}
229229

230230
// l 表示需要取的长度数量
@@ -285,7 +285,7 @@ func (m *mysql) formatSQL(col *core.Column) (f string, err error) {
285285
}
286286
return "0", nil
287287
case string:
288-
return "'" + vv + "'", nil
288+
return "'" + escapeApostrophe.Replace(vv) + "'", nil
289289
case time.Time: // datetime
290290
return formatTime(col, vv)
291291
case sql.NullTime: // datetime

dialect/postgres.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,9 @@ func (p *postgres) SQLType(col *core.Column) (string, error) {
205205
return "", invalidTimeFractional(col)
206206
}
207207
return p.buildType("TIMESTAMP", col, 1)
208+
default:
209+
return "", errUncovert(col)
208210
}
209-
210-
return "", errUncovert(col)
211211
}
212212

213213
// l 表示需要取的长度数量
@@ -255,7 +255,7 @@ func (p *postgres) formatSQL(col *core.Column) (f string, err error) {
255255

256256
switch vv := v.(type) {
257257
case string:
258-
return "'" + vv + "'", nil
258+
return "'" + quoteApostrophe.Replace(vv) + "'", nil
259259
case time.Time: // timestamp
260260
return formatTime(col, vv)
261261
case sql.NullTime: // timestamp

dialect/sqlite3.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,9 @@ func (s *sqlite3) SQLType(col *core.Column) (string, error) {
293293
return s.buildType("BLOB", col)
294294
case core.Time:
295295
return s.buildType("DATETIME", col)
296+
default:
297+
return "", errUncovert(col)
296298
}
297-
298-
return "", errUncovert(col)
299299
}
300300

301301
// l 表示需要取的长度数量
@@ -336,7 +336,7 @@ func (s *sqlite3) formatSQL(v any) (f string, err error) {
336336

337337
switch vv := v.(type) {
338338
case string:
339-
return "'" + vv + "'", nil
339+
return "'" + quoteApostrophe.Replace(vv) + "'", nil
340340
case time.Time: // timestamp
341341
return "'" + vv.In(time.UTC).Format(datetimeLayouts[0]) + "'", nil
342342
case sql.NullTime: // timestamp
@@ -351,10 +351,10 @@ func (s *sqlite3) Backup(dsn, dest string) error {
351351
dsn = dsn[:index]
352352
}
353353

354-
data,err :=os.ReadFile(dsn)
355-
if err!=nil{
354+
data, err := os.ReadFile(dsn)
355+
if err != nil {
356356
return err
357357
}
358358

359-
return os.WriteFile(dest, data , os.ModePerm)
359+
return os.WriteFile(dest, data, os.ModePerm)
360360
}

fetch/column.go

Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package fetch
66

77
import (
88
"database/sql"
9-
"reflect"
109

1110
"github.com/issue9/orm/v6/core"
1211
)
@@ -15,46 +14,11 @@ import (
1514
//
1615
// once 若为 true,则只导出第一条数据。
1716
// colName 指定需要导出的列名,若指定了不存在的名称,返回 error。
18-
func Column(once bool, colName string, rows *sql.Rows) ([]any, error) {
19-
cols, err := rows.Columns()
20-
if err != nil {
21-
return nil, err
22-
}
23-
24-
index := -1 // colName 列在 rows.Columns() 中的索引号
25-
buff := make([]any, len(cols))
26-
for i, v := range cols {
27-
var value any
28-
buff[i] = &value
29-
30-
if colName == v { // 获取 index 的值
31-
index = i
32-
}
33-
}
34-
35-
if index == -1 {
36-
return nil, core.ErrColumnNotFound(colName)
37-
}
38-
39-
var data []any
40-
for rows.Next() {
41-
if err := rows.Scan(buff...); err != nil {
42-
return nil, err
43-
}
44-
value := reflect.Indirect(reflect.ValueOf(buff[index]))
45-
data = append(data, value.Interface())
46-
if once {
47-
return data, nil
48-
}
49-
}
50-
51-
return data, nil
52-
}
53-
54-
// ColumnString 导出 rows 中某列的所有或是一行数据
5517
//
56-
// 功能等同于 [Column] 函数,但是返回值是 []string 而不是 []interface{}。
57-
func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) {
18+
// NOTE: 要求 T 的类型必须符合 [sql.Row.Scan] 的参数要求;
19+
func Column[T any](once bool, colName string, rows *sql.Rows) ([]T, error) {
20+
// TODO: 应该约束 T 为 sql.Rows.Scan 允许的类型,但是以目前 Go 的语法无法做到。
21+
5822
cols, err := rows.Columns()
5923
if err != nil {
6024
return nil, err
@@ -63,24 +27,26 @@ func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) {
6327
index := -1 // colName 列在 rows.Columns() 中的索引号
6428
buff := make([]any, len(cols))
6529
for i, v := range cols {
66-
var value string
67-
buff[i] = &value
68-
6930
if colName == v { // 获取 index 的值
7031
index = i
32+
var zero T
33+
buff[i] = &zero
34+
} else {
35+
var value any
36+
buff[i] = &value
7137
}
7238
}
7339

7440
if index == -1 {
7541
return nil, core.ErrColumnNotFound(colName)
7642
}
7743

78-
var data []string
44+
var data []T
7945
for rows.Next() {
8046
if err := rows.Scan(buff...); err != nil {
8147
return nil, err
8248
}
83-
data = append(data, *(buff[index].(*string)))
49+
data = append(data, *buff[index].(*T))
8450
if once {
8551
return data, nil
8652
}

fetch/column_test.go

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
package fetch_test
66

77
import (
8-
"reflect"
98
"testing"
109

1110
"github.com/issue9/assert/v4"
@@ -18,18 +17,6 @@ func TestColumn(t *testing.T) {
1817
a := assert.New(t, false)
1918
suite := test.NewSuite(a, "")
2019

21-
eq := func(s1, s2 []any) bool {
22-
if len(s1) != len(s2) {
23-
return false
24-
}
25-
for i, v := range s1 {
26-
if !reflect.DeepEqual(v, s2[i]) {
27-
return false
28-
}
29-
}
30-
return true
31-
}
32-
3320
suite.Run(func(t *test.Driver) {
3421
initDB(t)
3522
defer clearDB(t)
@@ -40,36 +27,28 @@ func TestColumn(t *testing.T) {
4027
rows, err := db.Query(sql)
4128
t.NotError(err).NotNil(rows)
4229

43-
cols, err := fetch.Column(false, "id", rows)
30+
cols, err := fetch.Column[int64](false, "id", rows)
4431
t.NotError(err).NotNil(cols)
4532

46-
if t.DriverName == "mysql" { // mysql 返回的是 []byte 类型
47-
eq(cols, []any{[]byte{'1'}, []byte{'2'}})
48-
} else {
49-
eq(cols, []any{int64(1), int64(2)})
50-
}
33+
t.Equal(cols, []int64{int64(1), int64(2)})
5134
t.NotError(rows.Close())
5235

5336
// 正常数据匹配,读取一行
5437
rows, err = db.Query(sql)
5538
t.NotError(err).NotNil(rows)
5639

57-
cols, err = fetch.Column(true, "id", rows)
40+
cols, err = fetch.Column[int64](true, "id", rows)
5841
t.NotError(err).NotNil(cols)
5942

60-
if t.DriverName == "mysql" { // mysql 返回的是 []byte 类型
61-
eq([]any{[]byte{'1'}}, cols)
62-
} else {
63-
eq([]any{int64(1)}, cols)
64-
}
43+
t.Equal(cols, []int64{int64(1)})
6544
t.NotError(rows.Close())
6645

6746
// 没有数据匹配,读取多行
6847
sql = `SELECT id,email FROM fetch_users WHERE id<0 ORDER BY id ASC`
6948
rows, err = db.Query(sql)
7049
t.NotError(err).NotNil(rows)
7150

72-
cols, err = fetch.Column(false, "id", rows)
51+
cols, err = fetch.Column[int64](false, "id", rows)
7352
t.NotError(err)
7453

7554
t.Empty(cols)
@@ -79,7 +58,7 @@ func TestColumn(t *testing.T) {
7958
rows, err = db.Query(sql)
8059
t.NotError(err).NotNil(rows)
8160

82-
cols, err = fetch.Column(true, "id", rows)
61+
cols, err = fetch.Column[int64](true, "id", rows)
8362
t.NotError(err)
8463

8564
t.Empty(cols)
@@ -89,7 +68,7 @@ func TestColumn(t *testing.T) {
8968
rows, err = db.Query(sql)
9069
t.NotError(err).NotNil(rows)
9170

92-
cols, err = fetch.Column(true, "not-exists", rows)
71+
cols, err = fetch.Column[int64](true, "not-exists", rows)
9372
t.Error(err)
9473

9574
t.Empty(cols)
@@ -111,7 +90,7 @@ func TestColumnString(t *testing.T) {
11190
rows, err := db.Query(sql)
11291
t.NotError(err).NotNil(rows)
11392

114-
cols, err := fetch.ColumnString(false, "id", rows)
93+
cols, err := fetch.Column[string](false, "id", rows)
11594
t.NotError(err).NotNil(cols)
11695

11796
t.Equal([]string{"1", "2"}, cols)
@@ -121,7 +100,7 @@ func TestColumnString(t *testing.T) {
121100
rows, err = db.Query(sql)
122101
t.NotError(err).NotNil(rows)
123102

124-
cols, err = fetch.ColumnString(true, "id", rows)
103+
cols, err = fetch.Column[string](true, "id", rows)
125104
t.NotError(err).NotNil(cols)
126105

127106
t.Equal([]string{"1"}, cols)
@@ -132,7 +111,7 @@ func TestColumnString(t *testing.T) {
132111
rows, err = db.Query(sql)
133112
t.NotError(err).NotNil(rows)
134113

135-
cols, err = fetch.ColumnString(false, "id", rows)
114+
cols, err = fetch.Column[string](false, "id", rows)
136115
t.NotError(err)
137116

138117
t.Empty(cols)
@@ -142,7 +121,7 @@ func TestColumnString(t *testing.T) {
142121
rows, err = db.Query(sql)
143122
t.NotError(err).NotNil(rows)
144123

145-
cols, err = fetch.ColumnString(true, "id", rows)
124+
cols, err = fetch.Column[string](true, "id", rows)
146125
t.NotError(err)
147126

148127
t.Empty(cols)
@@ -152,7 +131,7 @@ func TestColumnString(t *testing.T) {
152131
rows, err = db.Query(sql)
153132
t.NotError(err).NotNil(rows)
154133

155-
cols, err = fetch.ColumnString(true, "not-exists", rows)
134+
cols, err = fetch.Column[string](true, "not-exists", rows)
156135
t.Error(err)
157136

158137
t.Empty(cols)

internal/test/test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ func (s Suite) close() {
125125
for _, t := range s.drivers {
126126
t.NotError(t.DB.Close())
127127

128-
if t.DB.Dialect().DriverName() != Sqlite3.DriverName() {
128+
dn := t.DB.Dialect().DriverName()
129+
if dn != Sqlite3.DriverName() && dn != Sqlite.DriverName() {
129130
return
130131
}
131132

0 commit comments

Comments
 (0)