Skip to content

Commit

Permalink
Merge branch 'master' of github.com:issue9/orm
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed May 6, 2024
2 parents b203de5 + b8d1018 commit dcebc23
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 140 deletions.
6 changes: 6 additions & 0 deletions dialect/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@ package dialect

import (
"os/exec"
"strings"

"github.com/issue9/sliceutil"
)

var (
quoteApostrophe = strings.NewReplacer("'", "''") // 标准 SQL 用法
escapeApostrophe = strings.NewReplacer("'", "\\'") // mysql 用法
)

type base struct {
driverName string
name string
Expand Down
6 changes: 3 additions & 3 deletions dialect/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,9 @@ func (m *mysql) SQLType(col *core.Column) (string, error) {
return "", invalidTimeFractional(col)
}
return m.buildType("DATETIME", col, false, 1)
default:
return "", errUncovert(col)
}

return "", errUncovert(col)
}

// l 表示需要取的长度数量
Expand Down Expand Up @@ -285,7 +285,7 @@ func (m *mysql) formatSQL(col *core.Column) (f string, err error) {
}
return "0", nil
case string:
return "'" + vv + "'", nil
return "'" + escapeApostrophe.Replace(vv) + "'", nil
case time.Time: // datetime
return formatTime(col, vv)
case sql.NullTime: // datetime
Expand Down
6 changes: 3 additions & 3 deletions dialect/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ func (p *postgres) SQLType(col *core.Column) (string, error) {
return "", invalidTimeFractional(col)
}
return p.buildType("TIMESTAMP", col, 1)
default:
return "", errUncovert(col)
}

return "", errUncovert(col)
}

// l 表示需要取的长度数量
Expand Down Expand Up @@ -255,7 +255,7 @@ func (p *postgres) formatSQL(col *core.Column) (f string, err error) {

switch vv := v.(type) {
case string:
return "'" + vv + "'", nil
return "'" + quoteApostrophe.Replace(vv) + "'", nil
case time.Time: // timestamp
return formatTime(col, vv)
case sql.NullTime: // timestamp
Expand Down
12 changes: 6 additions & 6 deletions dialect/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,9 @@ func (s *sqlite3) SQLType(col *core.Column) (string, error) {
return s.buildType("BLOB", col)
case core.Time:
return s.buildType("DATETIME", col)
default:
return "", errUncovert(col)
}

return "", errUncovert(col)
}

// l 表示需要取的长度数量
Expand Down Expand Up @@ -336,7 +336,7 @@ func (s *sqlite3) formatSQL(v any) (f string, err error) {

switch vv := v.(type) {
case string:
return "'" + vv + "'", nil
return "'" + quoteApostrophe.Replace(vv) + "'", nil
case time.Time: // timestamp
return "'" + vv.In(time.UTC).Format(datetimeLayouts[0]) + "'", nil
case sql.NullTime: // timestamp
Expand All @@ -351,10 +351,10 @@ func (s *sqlite3) Backup(dsn, dest string) error {
dsn = dsn[:index]
}

data,err :=os.ReadFile(dsn)
if err!=nil{
data, err := os.ReadFile(dsn)
if err != nil {
return err
}

return os.WriteFile(dest, data , os.ModePerm)
return os.WriteFile(dest, data, os.ModePerm)
}
56 changes: 11 additions & 45 deletions fetch/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package fetch

import (
"database/sql"
"reflect"

"github.com/issue9/orm/v6/core"
)
Expand All @@ -15,46 +14,11 @@ import (
//
// once 若为 true,则只导出第一条数据。
// colName 指定需要导出的列名,若指定了不存在的名称,返回 error。
func Column(once bool, colName string, rows *sql.Rows) ([]any, error) {
cols, err := rows.Columns()
if err != nil {
return nil, err
}

index := -1 // colName 列在 rows.Columns() 中的索引号
buff := make([]any, len(cols))
for i, v := range cols {
var value any
buff[i] = &value

if colName == v { // 获取 index 的值
index = i
}
}

if index == -1 {
return nil, core.ErrColumnNotFound(colName)
}

var data []any
for rows.Next() {
if err := rows.Scan(buff...); err != nil {
return nil, err
}
value := reflect.Indirect(reflect.ValueOf(buff[index]))
data = append(data, value.Interface())
if once {
return data, nil
}
}

return data, nil
}

// ColumnString 导出 rows 中某列的所有或是一行数据
//
// 功能等同于 [Column] 函数,但是返回值是 []string 而不是 []interface{}。
func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) {
// NOTE: 要求 T 的类型必须符合 [sql.Row.Scan] 的参数要求;
func Column[T any](once bool, colName string, rows *sql.Rows) ([]T, error) {
// TODO: 应该约束 T 为 sql.Rows.Scan 允许的类型,但是以目前 Go 的语法无法做到。

cols, err := rows.Columns()
if err != nil {
return nil, err
Expand All @@ -63,24 +27,26 @@ func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) {
index := -1 // colName 列在 rows.Columns() 中的索引号
buff := make([]any, len(cols))
for i, v := range cols {
var value string
buff[i] = &value

if colName == v { // 获取 index 的值
index = i
var zero T
buff[i] = &zero
} else {
var value any
buff[i] = &value
}
}

if index == -1 {
return nil, core.ErrColumnNotFound(colName)
}

var data []string
var data []T
for rows.Next() {
if err := rows.Scan(buff...); err != nil {
return nil, err
}
data = append(data, *(buff[index].(*string)))
data = append(data, *buff[index].(*T))
if once {
return data, nil
}
Expand Down
45 changes: 12 additions & 33 deletions fetch/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package fetch_test

import (
"reflect"
"testing"

"github.com/issue9/assert/v4"
Expand All @@ -18,18 +17,6 @@ func TestColumn(t *testing.T) {
a := assert.New(t, false)
suite := test.NewSuite(a, "")

eq := func(s1, s2 []any) bool {
if len(s1) != len(s2) {
return false
}
for i, v := range s1 {
if !reflect.DeepEqual(v, s2[i]) {
return false
}
}
return true
}

suite.Run(func(t *test.Driver) {
initDB(t)
defer clearDB(t)
Expand All @@ -40,36 +27,28 @@ func TestColumn(t *testing.T) {
rows, err := db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err := fetch.Column(false, "id", rows)
cols, err := fetch.Column[int64](false, "id", rows)
t.NotError(err).NotNil(cols)

if t.DriverName == "mysql" { // mysql 返回的是 []byte 类型
eq(cols, []any{[]byte{'1'}, []byte{'2'}})
} else {
eq(cols, []any{int64(1), int64(2)})
}
t.Equal(cols, []int64{int64(1), int64(2)})
t.NotError(rows.Close())

// 正常数据匹配,读取一行
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.Column(true, "id", rows)
cols, err = fetch.Column[int64](true, "id", rows)
t.NotError(err).NotNil(cols)

if t.DriverName == "mysql" { // mysql 返回的是 []byte 类型
eq([]any{[]byte{'1'}}, cols)
} else {
eq([]any{int64(1)}, cols)
}
t.Equal(cols, []int64{int64(1)})
t.NotError(rows.Close())

// 没有数据匹配,读取多行
sql = `SELECT id,email FROM fetch_users WHERE id<0 ORDER BY id ASC`
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.Column(false, "id", rows)
cols, err = fetch.Column[int64](false, "id", rows)
t.NotError(err)

t.Empty(cols)
Expand All @@ -79,7 +58,7 @@ func TestColumn(t *testing.T) {
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.Column(true, "id", rows)
cols, err = fetch.Column[int64](true, "id", rows)
t.NotError(err)

t.Empty(cols)
Expand All @@ -89,7 +68,7 @@ func TestColumn(t *testing.T) {
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.Column(true, "not-exists", rows)
cols, err = fetch.Column[int64](true, "not-exists", rows)
t.Error(err)

t.Empty(cols)
Expand All @@ -111,7 +90,7 @@ func TestColumnString(t *testing.T) {
rows, err := db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err := fetch.ColumnString(false, "id", rows)
cols, err := fetch.Column[string](false, "id", rows)
t.NotError(err).NotNil(cols)

t.Equal([]string{"1", "2"}, cols)
Expand All @@ -121,7 +100,7 @@ func TestColumnString(t *testing.T) {
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.ColumnString(true, "id", rows)
cols, err = fetch.Column[string](true, "id", rows)
t.NotError(err).NotNil(cols)

t.Equal([]string{"1"}, cols)
Expand All @@ -132,7 +111,7 @@ func TestColumnString(t *testing.T) {
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.ColumnString(false, "id", rows)
cols, err = fetch.Column[string](false, "id", rows)
t.NotError(err)

t.Empty(cols)
Expand All @@ -142,7 +121,7 @@ func TestColumnString(t *testing.T) {
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.ColumnString(true, "id", rows)
cols, err = fetch.Column[string](true, "id", rows)
t.NotError(err)

t.Empty(cols)
Expand All @@ -152,7 +131,7 @@ func TestColumnString(t *testing.T) {
rows, err = db.Query(sql)
t.NotError(err).NotNil(rows)

cols, err = fetch.ColumnString(true, "not-exists", rows)
cols, err = fetch.Column[string](true, "not-exists", rows)
t.Error(err)

t.Empty(cols)
Expand Down
3 changes: 2 additions & 1 deletion internal/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ func (s Suite) close() {
for _, t := range s.drivers {
t.NotError(t.DB.Close())

if t.DB.Dialect().DriverName() != Sqlite3.DriverName() {
dn := t.DB.Dialect().DriverName()
if dn != Sqlite3.DriverName() && dn != Sqlite.DriverName() {
return
}

Expand Down
Loading

0 comments on commit dcebc23

Please sign in to comment.