From dd83e8e12130c7a543dd75a78b0b35152e5a7950 Mon Sep 17 00:00:00 2001 From: xormplus Date: Thu, 6 Dec 2018 01:08:59 +0800 Subject: [PATCH] add QueryExpr function --- README.md | 8 +++ engine.go | 2 +- engine_cond.go | 2 +- error.go | 2 + session_cond.go | 2 +- session_cond_test.go | 2 +- session_exist.go | 2 +- session_find.go | 2 +- session_query.go | 23 +++++++- session_query_test.go | 4 +- session_raw.go | 2 +- session_stats_test.go | 2 +- session_sum_test.go | 2 +- session_update.go | 2 +- sql_expr.go | 96 ++++++++++++++++++++++++++++++++++ statement.go | 39 ++++++++++++-- string_builder.go | 119 ++++++++++++++++++++++++++++++++++++++++++ 17 files changed, 293 insertions(+), 18 deletions(-) create mode 100644 sql_expr.go create mode 100644 string_builder.go diff --git a/README.md b/README.md index 806b7f0..ca0f94a 100644 --- a/README.md +++ b/README.md @@ -1110,6 +1110,14 @@ err := engine.Table("user").Select("user.*, detail.*") // SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 ``` +* 子查询 + +```Go +var student []Student +err = db.Table("student").Select("id ,name").Where("id in (?)", db.Table("studentinfo").Select("id").Where("status = ?", 2).QueryExpr()).Find(&student) +//SELECT id ,name FROM `student` WHERE (id in (SELECT id FROM `studentinfo` WHERE (status = 2))) +``` + * 根据条件遍历数据库,可以有两种方式: Iterate and Rows ```Go diff --git a/engine.go b/engine.go index cb10337..f52f572 100644 --- a/engine.go +++ b/engine.go @@ -20,7 +20,7 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) diff --git a/engine_cond.go b/engine_cond.go index fdf33de..a01714a 100644 --- a/engine_cond.go +++ b/engine_cond.go @@ -12,7 +12,7 @@ import ( "strings" "time" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) diff --git a/error.go b/error.go index d1abc27..ab8a611 100644 --- a/error.go +++ b/error.go @@ -30,6 +30,8 @@ var ( ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") + // ErrNeedMoreArguments need more arguments + ErrNeedMoreArguments = errors.New("Need more sql arguments") ) // ErrFieldIsNotExist columns does not exist diff --git a/session_cond.go b/session_cond.go index 568144f..b2c51f5 100644 --- a/session_cond.go +++ b/session_cond.go @@ -4,7 +4,7 @@ package xorm -import "github.com/go-xorm/builder" +import "github.com/xormplus/builder" // Sql provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. diff --git a/session_cond_test.go b/session_cond_test.go index ae4c2f8..c45b682 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -9,8 +9,8 @@ import ( "fmt" "testing" - "github.com/go-xorm/builder" "github.com/stretchr/testify/assert" + "github.com/xormplus/builder" ) func TestBuilder(t *testing.T) { diff --git a/session_exist.go b/session_exist.go index b5e60f9..4480834 100644 --- a/session_exist.go +++ b/session_exist.go @@ -9,7 +9,7 @@ import ( "fmt" "reflect" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) diff --git a/session_find.go b/session_find.go index d81705b..82f1242 100644 --- a/session_find.go +++ b/session_find.go @@ -10,7 +10,7 @@ import ( "reflect" "strings" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) diff --git a/session_query.go b/session_query.go index 59ac7ee..1d0acf4 100644 --- a/session_query.go +++ b/session_query.go @@ -11,7 +11,7 @@ import ( "strings" "time" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) @@ -384,3 +384,24 @@ func (session *Session) QueryInterface(sqlorArgs ...interface{}) ([]map[string]i return rows2Interfaces(rows) } + +// QueryExpr returns the query as bound SQL +func (session *Session) QueryExpr(sqlorArgs ...interface{}) sqlExpr { + if session.isAutoClose { + defer session.Close() + } + + sqlStr, args, err := session.genQuerySQL() + if err != nil { + session.engine.logger.Error(err) + return sqlExpr{sqlExpr: ""} + } + + sqlStr, err = ConvertToBoundSQL(sqlStr, args) + if err != nil { + session.engine.logger.Error(err) + return sqlExpr{sqlExpr: ""} + } + + return sqlExpr{sqlExpr: sqlStr} +} diff --git a/session_query_test.go b/session_query_test.go index 5ce89a4..c20cbdd 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/go-xorm/builder" - "github.com/go-xorm/core" + "github.com/xormplus/builder" + "github.com/xormplus/core" "github.com/stretchr/testify/assert" ) diff --git a/session_raw.go b/session_raw.go index d521c9f..9444386 100644 --- a/session_raw.go +++ b/session_raw.go @@ -9,7 +9,7 @@ import ( "reflect" "time" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) diff --git a/session_stats_test.go b/session_stats_test.go index b66a84b..8c0d314 100644 --- a/session_stats_test.go +++ b/session_stats_test.go @@ -9,8 +9,8 @@ import ( "strconv" "testing" - "github.com/go-xorm/builder" "github.com/stretchr/testify/assert" + "github.com/xormplus/builder" ) func isFloatEq(i, j float64, precision int) bool { diff --git a/session_sum_test.go b/session_sum_test.go index 2d2ad9b..12f61cc 100644 --- a/session_sum_test.go +++ b/session_sum_test.go @@ -9,8 +9,8 @@ import ( "strconv" "testing" - "github.com/go-xorm/builder" "github.com/stretchr/testify/assert" + "github.com/xormplus/builder" ) func isFloatEq(i, j float64, precision int) bool { diff --git a/session_update.go b/session_update.go index 42d1916..96ecd79 100644 --- a/session_update.go +++ b/session_update.go @@ -11,7 +11,7 @@ import ( "strconv" "strings" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) diff --git a/sql_expr.go b/sql_expr.go new file mode 100644 index 0000000..eb1c6d4 --- /dev/null +++ b/sql_expr.go @@ -0,0 +1,96 @@ +package xorm + +import ( + sql2 "database/sql" + "fmt" + "reflect" + "time" +) + +type sqlExpr struct { + sqlExpr string +} + +func noSQLQuoteNeeded(a interface{}) bool { + switch a.(type) { + case int, int8, int16, int32, int64: + return true + case uint, uint8, uint16, uint32, uint64: + return true + case float32, float64: + return true + case bool: + return true + case string: + return false + case time.Time, *time.Time: + return false + case sqlExpr, *sqlExpr: + return true + } + + t := reflect.TypeOf(a) + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Bool: + return true + case reflect.String: + return false + } + + return false +} + +// ConvertToBoundSQL will convert SQL and args to a bound SQL +func ConvertToBoundSQL(sql string, args []interface{}) (string, error) { + buf := StringBuilder{} + var i, j, start int + for ; i < len(sql); i++ { + if sql[i] == '?' { + _, err := buf.WriteString(sql[start:i]) + if err != nil { + return "", err + } + start = i + 1 + + if len(args) == j { + return "", ErrNeedMoreArguments + } + + arg := args[j] + + if exprArg, ok := arg.(sqlExpr); ok { + _, err = fmt.Fprint(&buf, exprArg.sqlExpr) + if err != nil { + return "", err + } + + } else { + if namedArg, ok := arg.(sql2.NamedArg); ok { + arg = namedArg.Value + } + + if noSQLQuoteNeeded(arg) { + _, err = fmt.Fprint(&buf, arg) + } else { + _, err = fmt.Fprintf(&buf, "'%v'", arg) + } + if err != nil { + return "", err + } + } + + j = j + 1 + } + } + _, err := buf.WriteString(sql[start:]) + if err != nil { + return "", err + } + return buf.String(), nil +} diff --git a/statement.go b/statement.go index cc65c30..ff3d80c 100644 --- a/statement.go +++ b/statement.go @@ -13,7 +13,7 @@ import ( "strings" "time" - "github.com/go-xorm/builder" + "github.com/xormplus/builder" "github.com/xormplus/core" ) @@ -146,8 +146,22 @@ func (statement *Statement) Where(query interface{}, args ...interface{}) *State func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { switch query.(type) { case string: - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.And(cond) + isExpr := false + var cargs []interface{} + for i, _ := range args { + if _, ok := args[i].(sqlExpr); ok { + isExpr = true + } + cargs = append(cargs, args[i]) + } + if isExpr { + sqlStr, _ := ConvertToBoundSQL(query.(string), cargs) + cond := builder.Expr(sqlStr) + statement.cond = statement.cond.And(cond) + } else { + cond := builder.Expr(query.(string), args...) + statement.cond = statement.cond.And(cond) + } case map[string]interface{}: cond := builder.Eq(query.(map[string]interface{})) statement.cond = statement.cond.And(cond) @@ -170,8 +184,23 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { switch query.(type) { case string: - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.Or(cond) + isExpr := false + var cargs []interface{} + for i, _ := range args { + if _, ok := args[i].(sqlExpr); ok { + isExpr = true + } + cargs = append(cargs, args[i]) + } + if isExpr { + sqlStr, _ := ConvertToBoundSQL(query.(string), cargs) + cond := builder.Expr(sqlStr) + statement.cond = statement.cond.Or(cond) + } else { + cond := builder.Expr(query.(string), args...) + statement.cond = statement.cond.Or(cond) + } + case map[string]interface{}: cond := builder.Eq(query.(map[string]interface{})) statement.cond = statement.cond.Or(cond) diff --git a/string_builder.go b/string_builder.go new file mode 100644 index 0000000..f462771 --- /dev/null +++ b/string_builder.go @@ -0,0 +1,119 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "unicode/utf8" + "unsafe" +) + +// A StringBuilder is used to efficiently build a string using Write methods. +// It minimizes memory copying. The zero value is ready to use. +// Do not copy a non-zero Builder. +type StringBuilder struct { + addr *StringBuilder // of receiver, to detect copies by value + buf []byte +} + +// noescape hides a pointer from escape analysis. noescape is +// the identity function but escape analysis doesn't think the +// output depends on the input. noescape is inlined and currently +// compiles down to zero instructions. +// USE CAREFULLY! +// This was copied from the runtime; see issues 23382 and 7921. +//go:nosplit +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +func (b *StringBuilder) copyCheck() { + if b.addr == nil { + // This hack works around a failing of Go's escape analysis + // that was causing b to escape and be heap allocated. + // See issue 23382. + // TODO: once issue 7921 is fixed, this should be reverted to + // just "b.addr = b". + b.addr = (*StringBuilder)(noescape(unsafe.Pointer(b))) + } else if b.addr != b { + panic("strings: illegal use of non-zero Builder copied by value") + } +} + +// String returns the accumulated string. +func (b *StringBuilder) String() string { + return *(*string)(unsafe.Pointer(&b.buf)) +} + +// Len returns the number of accumulated bytes; b.Len() == len(b.String()). +func (b *StringBuilder) Len() int { return len(b.buf) } + +// Reset resets the Builder to be empty. +func (b *StringBuilder) Reset() { + b.addr = nil + b.buf = nil +} + +// grow copies the buffer to a new, larger buffer so that there are at least n +// bytes of capacity beyond len(b.buf). +func (b *StringBuilder) grow(n int) { + buf := make([]byte, len(b.buf), 2*cap(b.buf)+n) + copy(buf, b.buf) + b.buf = buf +} + +// Grow grows b's capacity, if necessary, to guarantee space for +// another n bytes. After Grow(n), at least n bytes can be written to b +// without another allocation. If n is negative, Grow panics. +func (b *StringBuilder) Grow(n int) { + b.copyCheck() + if n < 0 { + panic("strings.Builder.Grow: negative count") + } + if cap(b.buf)-len(b.buf) < n { + b.grow(n) + } +} + +// Write appends the contents of p to b's buffer. +// Write always returns len(p), nil. +func (b *StringBuilder) Write(p []byte) (int, error) { + b.copyCheck() + b.buf = append(b.buf, p...) + return len(p), nil +} + +// WriteByte appends the byte c to b's buffer. +// The returned error is always nil. +func (b *StringBuilder) WriteByte(c byte) error { + b.copyCheck() + b.buf = append(b.buf, c) + return nil +} + +// WriteRune appends the UTF-8 encoding of Unicode code point r to b's buffer. +// It returns the length of r and a nil error. +func (b *StringBuilder) WriteRune(r rune) (int, error) { + b.copyCheck() + if r < utf8.RuneSelf { + b.buf = append(b.buf, byte(r)) + return 1, nil + } + l := len(b.buf) + if cap(b.buf)-l < utf8.UTFMax { + b.grow(utf8.UTFMax) + } + n := utf8.EncodeRune(b.buf[l:l+utf8.UTFMax], r) + b.buf = b.buf[:l+n] + return n, nil +} + +// WriteString appends the contents of s to b's buffer. +// It returns the length of s and a nil error. +func (b *StringBuilder) WriteString(s string) (int, error) { + b.copyCheck() + b.buf = append(b.buf, s...) + return len(s), nil +}