diff --git a/dialect/base.go b/dialect/base.go index 28332c7..057689b 100644 --- a/dialect/base.go +++ b/dialect/base.go @@ -6,10 +6,16 @@ package dialect import ( "os/exec" + "strings" "github.com/issue9/sliceutil" ) +var ( + quoteApostrophe = strings.NewReplacer("'", "''") // 标准 SQL 用法 + escapeApostrophe = strings.NewReplacer("'", "\\'") // mysql 用法 +) + type base struct { driverName string name string diff --git a/dialect/mysql.go b/dialect/mysql.go index 7709d3a..9fdbf8c 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -222,9 +222,9 @@ func (m *mysql) SQLType(col *core.Column) (string, error) { return "", invalidTimeFractional(col) } return m.buildType("DATETIME", col, false, 1) + default: + return "", errUncovert(col) } - - return "", errUncovert(col) } // l 表示需要取的长度数量 @@ -285,7 +285,7 @@ func (m *mysql) formatSQL(col *core.Column) (f string, err error) { } return "0", nil case string: - return "'" + vv + "'", nil + return "'" + escapeApostrophe.Replace(vv) + "'", nil case time.Time: // datetime return formatTime(col, vv) case sql.NullTime: // datetime diff --git a/dialect/postgres.go b/dialect/postgres.go index da6e406..bb8cc2f 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -205,9 +205,9 @@ func (p *postgres) SQLType(col *core.Column) (string, error) { return "", invalidTimeFractional(col) } return p.buildType("TIMESTAMP", col, 1) + default: + return "", errUncovert(col) } - - return "", errUncovert(col) } // l 表示需要取的长度数量 @@ -255,7 +255,7 @@ func (p *postgres) formatSQL(col *core.Column) (f string, err error) { switch vv := v.(type) { case string: - return "'" + vv + "'", nil + return "'" + quoteApostrophe.Replace(vv) + "'", nil case time.Time: // timestamp return formatTime(col, vv) case sql.NullTime: // timestamp diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index f072b4e..1d042a1 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -293,9 +293,9 @@ func (s *sqlite3) SQLType(col *core.Column) (string, error) { return s.buildType("BLOB", col) case core.Time: return s.buildType("DATETIME", col) + default: + return "", errUncovert(col) } - - return "", errUncovert(col) } // l 表示需要取的长度数量 @@ -336,7 +336,7 @@ func (s *sqlite3) formatSQL(v any) (f string, err error) { switch vv := v.(type) { case string: - return "'" + vv + "'", nil + return "'" + quoteApostrophe.Replace(vv) + "'", nil case time.Time: // timestamp return "'" + vv.In(time.UTC).Format(datetimeLayouts[0]) + "'", nil case sql.NullTime: // timestamp @@ -351,10 +351,10 @@ func (s *sqlite3) Backup(dsn, dest string) error { dsn = dsn[:index] } - data,err :=os.ReadFile(dsn) - if err!=nil{ + data, err := os.ReadFile(dsn) + if err != nil { return err } - return os.WriteFile(dest, data , os.ModePerm) + return os.WriteFile(dest, data, os.ModePerm) } diff --git a/fetch/column.go b/fetch/column.go index 8291f0b..8618fbf 100644 --- a/fetch/column.go +++ b/fetch/column.go @@ -6,7 +6,6 @@ package fetch import ( "database/sql" - "reflect" "github.com/issue9/orm/v6/core" ) @@ -15,46 +14,11 @@ import ( // // once 若为 true,则只导出第一条数据。 // colName 指定需要导出的列名,若指定了不存在的名称,返回 error。 -func Column(once bool, colName string, rows *sql.Rows) ([]any, error) { - cols, err := rows.Columns() - if err != nil { - return nil, err - } - - index := -1 // colName 列在 rows.Columns() 中的索引号 - buff := make([]any, len(cols)) - for i, v := range cols { - var value any - buff[i] = &value - - if colName == v { // 获取 index 的值 - index = i - } - } - - if index == -1 { - return nil, core.ErrColumnNotFound(colName) - } - - var data []any - for rows.Next() { - if err := rows.Scan(buff...); err != nil { - return nil, err - } - value := reflect.Indirect(reflect.ValueOf(buff[index])) - data = append(data, value.Interface()) - if once { - return data, nil - } - } - - return data, nil -} - -// ColumnString 导出 rows 中某列的所有或是一行数据 // -// 功能等同于 [Column] 函数,但是返回值是 []string 而不是 []interface{}。 -func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) { +// NOTE: 要求 T 的类型必须符合 [sql.Row.Scan] 的参数要求; +func Column[T any](once bool, colName string, rows *sql.Rows) ([]T, error) { + // TODO: 应该约束 T 为 sql.Rows.Scan 允许的类型,但是以目前 Go 的语法无法做到。 + cols, err := rows.Columns() if err != nil { return nil, err @@ -63,11 +27,13 @@ func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) { index := -1 // colName 列在 rows.Columns() 中的索引号 buff := make([]any, len(cols)) for i, v := range cols { - var value string - buff[i] = &value - if colName == v { // 获取 index 的值 index = i + var zero T + buff[i] = &zero + } else { + var value any + buff[i] = &value } } @@ -75,12 +41,12 @@ func ColumnString(once bool, colName string, rows *sql.Rows) ([]string, error) { return nil, core.ErrColumnNotFound(colName) } - var data []string + var data []T for rows.Next() { if err := rows.Scan(buff...); err != nil { return nil, err } - data = append(data, *(buff[index].(*string))) + data = append(data, *buff[index].(*T)) if once { return data, nil } diff --git a/fetch/column_test.go b/fetch/column_test.go index 1354cd3..31fb84d 100644 --- a/fetch/column_test.go +++ b/fetch/column_test.go @@ -5,7 +5,6 @@ package fetch_test import ( - "reflect" "testing" "github.com/issue9/assert/v4" @@ -18,18 +17,6 @@ func TestColumn(t *testing.T) { a := assert.New(t, false) suite := test.NewSuite(a, "") - eq := func(s1, s2 []any) bool { - if len(s1) != len(s2) { - return false - } - for i, v := range s1 { - if !reflect.DeepEqual(v, s2[i]) { - return false - } - } - return true - } - suite.Run(func(t *test.Driver) { initDB(t) defer clearDB(t) @@ -40,28 +27,20 @@ func TestColumn(t *testing.T) { rows, err := db.Query(sql) t.NotError(err).NotNil(rows) - cols, err := fetch.Column(false, "id", rows) + cols, err := fetch.Column[int64](false, "id", rows) t.NotError(err).NotNil(cols) - if t.DriverName == "mysql" { // mysql 返回的是 []byte 类型 - eq(cols, []any{[]byte{'1'}, []byte{'2'}}) - } else { - eq(cols, []any{int64(1), int64(2)}) - } + t.Equal(cols, []int64{int64(1), int64(2)}) t.NotError(rows.Close()) // 正常数据匹配,读取一行 rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.Column(true, "id", rows) + cols, err = fetch.Column[int64](true, "id", rows) t.NotError(err).NotNil(cols) - if t.DriverName == "mysql" { // mysql 返回的是 []byte 类型 - eq([]any{[]byte{'1'}}, cols) - } else { - eq([]any{int64(1)}, cols) - } + t.Equal(cols, []int64{int64(1)}) t.NotError(rows.Close()) // 没有数据匹配,读取多行 @@ -69,7 +48,7 @@ func TestColumn(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.Column(false, "id", rows) + cols, err = fetch.Column[int64](false, "id", rows) t.NotError(err) t.Empty(cols) @@ -79,7 +58,7 @@ func TestColumn(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.Column(true, "id", rows) + cols, err = fetch.Column[int64](true, "id", rows) t.NotError(err) t.Empty(cols) @@ -89,7 +68,7 @@ func TestColumn(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.Column(true, "not-exists", rows) + cols, err = fetch.Column[int64](true, "not-exists", rows) t.Error(err) t.Empty(cols) @@ -111,7 +90,7 @@ func TestColumnString(t *testing.T) { rows, err := db.Query(sql) t.NotError(err).NotNil(rows) - cols, err := fetch.ColumnString(false, "id", rows) + cols, err := fetch.Column[string](false, "id", rows) t.NotError(err).NotNil(cols) t.Equal([]string{"1", "2"}, cols) @@ -121,7 +100,7 @@ func TestColumnString(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.ColumnString(true, "id", rows) + cols, err = fetch.Column[string](true, "id", rows) t.NotError(err).NotNil(cols) t.Equal([]string{"1"}, cols) @@ -132,7 +111,7 @@ func TestColumnString(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.ColumnString(false, "id", rows) + cols, err = fetch.Column[string](false, "id", rows) t.NotError(err) t.Empty(cols) @@ -142,7 +121,7 @@ func TestColumnString(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.ColumnString(true, "id", rows) + cols, err = fetch.Column[string](true, "id", rows) t.NotError(err) t.Empty(cols) @@ -152,7 +131,7 @@ func TestColumnString(t *testing.T) { rows, err = db.Query(sql) t.NotError(err).NotNil(rows) - cols, err = fetch.ColumnString(true, "not-exists", rows) + cols, err = fetch.Column[string](true, "not-exists", rows) t.Error(err) t.Empty(cols) diff --git a/internal/test/test.go b/internal/test/test.go index 68ea782..5cf6d9c 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -125,7 +125,8 @@ func (s Suite) close() { for _, t := range s.drivers { t.NotError(t.DB.Close()) - if t.DB.Dialect().DriverName() != Sqlite3.DriverName() { + dn := t.DB.Dialect().DriverName() + if dn != Sqlite3.DriverName() && dn != Sqlite.DriverName() { return } diff --git a/sqlbuilder/select.go b/sqlbuilder/select.go index 21e1a15..4ef9015 100644 --- a/sqlbuilder/select.go +++ b/sqlbuilder/select.go @@ -8,7 +8,6 @@ import ( "context" "database/sql" "errors" - "strconv" "github.com/issue9/orm/v6/core" "github.com/issue9/orm/v6/fetch" @@ -454,11 +453,7 @@ func (stmt *SelectStmt) QueryString(colName string) (v string, err error) { } func (stmt *SelectStmt) QueryStringContext(ctx context.Context, colName string) (v string, err error) { - rows, err := stmt.QueryContext(ctx) - if err != nil { - return "", err - } - return fetchString(rows, colName) + return fetchSelectStmtColumn[string](stmt, ctx, colName) } // QueryFloat 查询指定列的第一行数据,并将其转换成 float64 @@ -467,12 +462,7 @@ func (stmt *SelectStmt) QueryFloat(colName string) (float64, error) { } func (stmt *SelectStmt) QueryFloatContext(ctx context.Context, colName string) (float64, error) { - v, err := stmt.QueryStringContext(ctx, colName) - if err != nil { - return 0, err - } - - return strconv.ParseFloat(v, 64) + return fetchSelectStmtColumn[float64](stmt, ctx, colName) } // QueryInt 查询指定列的第一行数据,并将其转换成 int64 @@ -481,15 +471,7 @@ func (stmt *SelectStmt) QueryInt(colName string) (int64, error) { } func (stmt *SelectStmt) QueryIntContext(ctx context.Context, colName string) (int64, error) { - // NOTE: 可能会出现浮点数的情况。比如: - // select avg(xx) as avg form xxx where xxx - // 查询 avg 的值可能是 5.000 等值。 - v, err := stmt.QueryStringContext(ctx, colName) - if err != nil { - return 0, err - } - - return strconv.ParseInt(v, 10, 64) + return fetchSelectStmtColumn[int64](stmt, ctx, colName) } // Select 生成 select 语句 @@ -523,35 +505,18 @@ func (stmt *SelectQuery) QueryObject(strict bool, objs any, arg ...any) (size in } // QueryString 查询指定列的第一行数据,并将其转换成 string -func (stmt *SelectQuery) QueryString(colName string, arg ...any) (v string, err error) { - rows, err := stmt.stmt.Query(arg...) - if err != nil { - return "", err - } - return fetchString(rows, colName) +func (stmt *SelectQuery) QueryString(colName string, arg ...any) (string, error) { + return fetchSelectQueryColumn[string](stmt, colName, arg...) } // QueryFloat 查询指定列的第一行数据,并将其转换成 float64 func (stmt *SelectQuery) QueryFloat(colName string, arg ...any) (float64, error) { - v, err := stmt.QueryString(colName, arg...) - if err != nil { - return 0, err - } - - return strconv.ParseFloat(v, 64) + return fetchSelectQueryColumn[float64](stmt, colName, arg...) } // QueryInt 查询指定列的第一行数据,并将其转换成 int64 func (stmt *SelectQuery) QueryInt(colName string, arg ...any) (int64, error) { - // NOTE: 可能会出现浮点数的情况。比如: - // select avg(xx) as avg form xxx where xxx - // 查询 avg 的值可能是 5.000 等值。 - v, err := stmt.QueryString(colName, arg...) - if err != nil { - return 0, err - } - - return strconv.ParseInt(v, 10, 64) + return fetchSelectQueryColumn[int64](stmt, colName, arg...) } func (stmt *SelectQuery) Close() error { return stmt.stmt.Close() } @@ -559,19 +524,35 @@ func (stmt *SelectQuery) Close() error { return stmt.stmt.Close() } func fetchObject(rows *sql.Rows, strict bool, objs any) (size int, err error) { defer func() { err = errors.Join(err, rows.Close()) }() size, err = fetch.Object(strict, rows, objs) - return + return // 注意 defer,独立为一行 +} + +func fetchSelectStmtColumn[T any](stmt *SelectStmt, ctx context.Context, colName string) (v T, err error) { + rows, err := stmt.QueryContext(ctx) + if err != nil { + return v, err + } + return fetchColumn[T](rows, colName) +} + +func fetchSelectQueryColumn[T any](stmt *SelectQuery, colName string, arg ...any) (v T, err error) { + rows, err := stmt.stmt.Query(arg...) + if err != nil { + return v, err + } + return fetchColumn[T](rows, colName) } -func fetchString(rows *sql.Rows, colName string) (v string, err error) { +func fetchColumn[T any](rows *sql.Rows, colName string) (v T, err error) { defer func() { err = errors.Join(err, rows.Close()) }() - cols, err := fetch.ColumnString(true, colName, rows) + cols, err := fetch.Column[T](true, colName, rows) if err != nil { - return "", err + return v, err } if len(cols) == 0 { - return "", ErrNoData + return v, ErrNoData } return cols[0], nil diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index e1a6ccd..6b3ff6e 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -509,7 +509,7 @@ func (stmt *TableExistsStmt) Exists() (bool, error) { return false, err } - name, err := fetchString(rows, "name") + name, err := fetchColumn[string](rows, "name") switch { case errors.Is(err, ErrNoData): return false, nil diff --git a/sqlbuilder/view.go b/sqlbuilder/view.go index 596ee04..7582ff2 100644 --- a/sqlbuilder/view.go +++ b/sqlbuilder/view.go @@ -195,7 +195,7 @@ func (stmt *ViewExistsStmt) Exists() (bool, error) { return false, err } - name, err := fetchString(rows, "name") + name, err := fetchColumn[string](rows, "name") switch { case errors.Is(err, ErrNoData): return false, nil