Skip to content

Commit

Permalink
feature: 添加调用选项real_row_count,设置是否通过count(*)获取真正受影响行数.默认值false
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchuanchuan committed Jul 28, 2019
1 parent ffa50a8 commit 7ffa335
Show file tree
Hide file tree
Showing 5 changed files with 379 additions and 248 deletions.
40 changes: 40 additions & 0 deletions session/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package session

import (
"fmt"
"vitess.io/vitess/go/vt/sqlparser"
)

Expand Down Expand Up @@ -104,3 +105,42 @@ func insert2Select(stmt *sqlparser.Insert) string {

return "select 1 from DUAL"
}

// select2Count : SELECT 转成 COUNT语句
func (rw *Rewrite) select2Count() string {
if rw.Stmt == nil {
return fmt.Sprintf("SELECT COUNT(1) FROM (%s)t", rw.SQL)
}
// log.Infof("%#v", rw.Stmt)

switch stmt := rw.Stmt.(type) {
case *sqlparser.Select:
if stmt.Distinct != "" || stmt.GroupBy != nil || stmt.Having != nil {
return fmt.Sprintf("SELECT COUNT(1) FROM (%s)t", rw.SQL)
}

newSQL := &sqlparser.Select{
SelectExprs: []sqlparser.SelectExpr{
&sqlparser.AliasedExpr{
Expr: &sqlparser.FuncExpr{
Name: sqlparser.NewColIdent("count"),
Exprs: []sqlparser.SelectExpr{
new(sqlparser.StarExpr),
},
},
},
},
Distinct: stmt.Distinct,
From: stmt.From,
Where: stmt.Where,
GroupBy: stmt.GroupBy,
Having: stmt.Having,
Limit: stmt.Limit,
}
return sqlparser.String(newSQL)
// case *sqlparser.Union, *sqlparser.ParenSelect:
// return fmt.Sprintf("SELECT COUNT(1) FROM (%s)t", rw.SQL)
default:
return fmt.Sprintf("SELECT COUNT(1) FROM (%s)t", rw.SQL)
}
}
130 changes: 113 additions & 17 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/hanchuanchuan/goInception/util/auth"
"github.com/hanchuanchuan/goInception/util/sqlexec"
"github.com/hanchuanchuan/goInception/util/stringutil"
// "vitess.io/vitess/go/vt/sqlparser"
// "github.com/hanchuanchuan/parser/ast"
"github.com/jinzhu/gorm"
"github.com/percona/go-mysql/query"
Expand Down Expand Up @@ -126,6 +127,9 @@ type sourceOptions struct {

// DDL/DML分隔功能
split bool

// 使用count(*)计算受影响行数
realRowCount bool
}

// ExplainInfo 执行计划信息
Expand Down Expand Up @@ -2235,7 +2239,8 @@ func (s *session) parseOptions(sql string) {

Print: viper.GetBool("queryPrint"),

split: viper.GetBool("split"),
split: viper.GetBool("split"),
realRowCount: viper.GetBool("realRowCount"),
}

if s.opt.split || s.opt.check || s.opt.Print {
Expand Down Expand Up @@ -5736,6 +5741,73 @@ func (s *session) getExplainInfo(sql string, sqlId string) {
}
}

// getRealRowCount: 获取真正的受影响行数
func (s *session) getRealRowCount(sql string, sqlId string) {

if s.hasError() {
return
}

var newRecord *Record
if s.Inc.EnableFingerprint && sqlId != "" {
newRecord = &Record{
Buf: new(bytes.Buffer),
}
}
r := s.myRecord

var value int
rows, err := s.Raw(sql)
if rows != nil {
defer rows.Close()
}

if err != nil {
// log.Error(err)
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
s.AppendErrorMessage(myErr.Message)
if newRecord != nil {
newRecord.AppendErrorMessage(myErr.Message)
}
} else {
s.AppendErrorMessage(err.Error())
if newRecord != nil {
newRecord.AppendErrorMessage(myErr.Message)
}
}
return
} else {
for rows.Next() {
rows.Scan(&value)
}
}

// log.Info(sql)
// log.Info(value)

r.AffectedRows = value
if newRecord != nil {
newRecord.AffectedRows = r.AffectedRows
}

if s.Inc.MaxUpdateRows > 0 && r.AffectedRows > int(s.Inc.MaxUpdateRows) {
switch r.Type.(type) {
case *ast.DeleteStmt, *ast.UpdateStmt:
s.AppendErrorNo(ER_UDPATE_TOO_MUCH_ROWS,
r.AffectedRows, s.Inc.MaxUpdateRows)
if newRecord != nil {
newRecord.AppendErrorNo(ER_UDPATE_TOO_MUCH_ROWS,
r.AffectedRows, s.Inc.MaxUpdateRows)
}
}
}

if newRecord != nil {
s.sqlFingerprint[sqlId] = newRecord
}
}

func (s *session) explainOrAnalyzeSql(sql string) {

// // 如果没有表结构,或者新增表 or 新增列时,不做explain
Expand All @@ -5749,38 +5821,62 @@ func (s *session) explainOrAnalyzeSql(sql string) {
return
}

if s.DBVersion < 50600 {
if s.opt.realRowCount {
// dml转换成select
rw, err := NewRewrite(sql)
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
rw, err = rw.Rewrite()
rw, err = rw.RewriteDML2Select()
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
sql = rw.NewSQL
if sql == "" {
return
stmt, err := NewRewrite(rw.NewSQL)
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
sql = stmt.select2Count()
// log.Info(sql)
s.getRealRowCount(sql, sqlId)
}
}
}
return
} else {
if s.DBVersion < 50600 {
rw, err := NewRewrite(sql)
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
rw, err = rw.RewriteDML2Select()
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
s.AppendErrorMessage(err.Error())
} else {
sql = rw.NewSQL
if sql == "" {
return
}
}
}
}
}

var explain []string

if s.isMiddleware() {
explain = append(explain, s.opt.middlewareExtend)
}
var explain []string

explain = append(explain, "EXPLAIN ")
explain = append(explain, sql)
if s.isMiddleware() {
explain = append(explain, s.opt.middlewareExtend)
}

// rows := s.getExplainInfo(strings.Join(explain, ""))
s.getExplainInfo(strings.Join(explain, ""), sqlId)
explain = append(explain, "EXPLAIN ")
explain = append(explain, sql)

// s.AnlyzeExplain(rows)
// rows := s.getExplainInfo(strings.Join(explain, ""))
s.getExplainInfo(strings.Join(explain, ""), sqlId)
}
}

func (s *session) AnlyzeExplain(rows []ExplainInfo) {
Expand Down
8 changes: 6 additions & 2 deletions session/session_inception_backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ type testSessionIncBackupSuite struct {
sqlMode string
// 时间戳类型是否需要明确指定默认值
explicitDefaultsForTimestamp bool

realRowCount bool
}

func (s *testSessionIncBackupSuite) SetUpSuite(c *C) {
Expand All @@ -58,6 +60,8 @@ func (s *testSessionIncBackupSuite) SetUpSuite(c *C) {
c.Skip("skipping test; in TRAVIS mode")
}

s.realRowCount = true

testleak.BeforeTest()
s.cluster = mocktikv.NewCluster()
mocktikv.BootstrapWithSingleStore(s.cluster)
Expand Down Expand Up @@ -120,12 +124,12 @@ func (s *testSessionIncBackupSuite) TearDownTest(c *C) {
}

func (s *testSessionIncBackupSuite) makeSQL(c *C, tk *testkit.TestKit, sql string) *testkit.Result {
a := `/*--user=test;--password=test;--host=127.0.0.1;--execute=1;--backup=1;--port=3306;--enable-ignore-warnings;*/
a := `/*--user=test;--password=test;--host=127.0.0.1;--execute=1;--backup=1;--port=3306;--enable-ignore-warnings;real_row_count=%v;*/
inception_magic_start;
use test_inc;
%s;
inception_magic_commit;`
res := tk.MustQueryInc(fmt.Sprintf(a, sql))
res := tk.MustQueryInc(fmt.Sprintf(a, s.realRowCount, sql))

// 需要成功执行
for _, row := range res.Rows() {
Expand Down
Loading

0 comments on commit 7ffa335

Please sign in to comment.