From f4362e59714c350bf6b33579fc1e3c1d4632c3ce Mon Sep 17 00:00:00 2001 From: hanchuanchuan Date: Thu, 2 Apr 2020 16:55:30 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=9C=A8=E4=BA=8B?= =?UTF-8?q?=E5=8A=A1=E4=B8=ADDDL=E5=92=8CDML=E6=B7=B7=E5=90=88=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E6=97=B6=E5=8F=AF=E8=83=BD=E5=87=BA=E9=94=99=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98=20(#182)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/conn.go | 80 +++++++++++++++++++++++- session/session.go | 3 + session/session_inception.go | 44 ++++++++++--- session/session_inception_common_test.go | 8 +-- session/session_inception_tran_test.go | 53 ++++++++++++++++ 5 files changed, 174 insertions(+), 14 deletions(-) diff --git a/session/conn.go b/session/conn.go index 0ff00d8c..bcd011ff 100644 --- a/session/conn.go +++ b/session/conn.go @@ -87,7 +87,7 @@ func (s *session) Raw(sqlStr string) (rows *sql.Rows, err error) { return } -// Raw 执行sql语句,连接失败时自动重连,自动重置当前数据库 +// Exec 执行sql语句,连接失败时自动重连,自动重置当前数据库 func (s *session) Exec(sqlStr string, retry bool) (res sql.Result, err error) { // 连接断开无效时,自动重试 for i := 0; i < maxBadConnRetries; i++ { @@ -114,6 +114,33 @@ func (s *session) Exec(sqlStr string, retry bool) (res sql.Result, err error) { return } +// ExecDDL 执行sql语句,连接失败时自动重连,自动重置当前数据库 +func (s *session) ExecDDL(sqlStr string, retry bool) (res sql.Result, err error) { + // 连接断开无效时,自动重试 + for i := 0; i < maxBadConnRetries; i++ { + res, err = s.ddlDB.DB().Exec(sqlStr) + if err == nil { + return + } else { + log.Errorf("con:%d %v sql:%s", s.sessionVars.ConnectionID, err, sqlStr) + if err == mysqlDriver.ErrInvalidConn { + err1 := s.initConnection() + if err1 != nil { + return res, err1 + } + if retry { + s.AppendErrorMessage(mysqlDriver.ErrInvalidConn.Error()) + continue + } else { + return + } + } + return + } + } + return +} + // Raw 执行sql语句,连接失败时自动重连,自动重置当前数据库 func (s *session) RawScan(sqlStr string, dest interface{}) (err error) { // 连接断开无效时,自动重试 @@ -180,3 +207,54 @@ func (s *session) initConnection() (err error) { } return } + +// SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空) +func (s *session) SwitchDatabase(db *gorm.DB) error { + name := s.DBName + if name == "" { + name = s.opt.db + } + if name == "" { + return nil + } + + // log.Infof("SwitchDatabase: %v", name) + _, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name)) + if err != nil { + log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) + if myErr, ok := err.(*mysqlDriver.MySQLError); ok { + s.AppendErrorMessage(myErr.Message) + } else { + s.AppendErrorMessage(err.Error()) + } + } + return err +} + +// GetDatabase 获取当前数据库 +func (s *session) GetDatabase() string { + log.Debug("GetDatabase") + + var value string + sql := "select database();" + + rows, err := s.Raw(sql) + if rows != nil { + defer rows.Close() + } + + if err != nil { + log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) + if myErr, ok := err.(*mysqlDriver.MySQLError); ok { + s.AppendErrorMessage(myErr.Message) + } else { + s.AppendErrorMessage(err.Error()) + } + } else { + for rows.Next() { + rows.Scan(&value) + } + } + + return value +} diff --git a/session/session.go b/session/session.go index bf09cfa0..60424f81 100644 --- a/session/session.go +++ b/session/session.go @@ -169,6 +169,9 @@ type session struct { db *gorm.DB backupdb *gorm.DB + // 执行DDL操作的数据库连接. 仅用于事务功能 + ddlDB *gorm.DB + DBName string myRecord *Record diff --git a/session/session_inception.go b/session/session_inception.go index 719cf774..82de5670 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -21,6 +21,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "database/sql" "database/sql/driver" "fmt" "io/ioutil" @@ -523,6 +524,9 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle if s.db != nil { defer s.db.Close() } + if s.ddlDB != nil { + defer s.ddlDB.Close() + } if s.backupdb != nil { defer s.backupdb.Close() } @@ -1653,6 +1657,8 @@ func (s *session) executeAllStatement(ctx context.Context) { trans = make([]*Record, 0, s.opt.tranBatch) } + // 用于事务. 判断是否为DML语句 + lastIsDMLTrans := false for i, record := range s.recordSets.All() { // 忽略不需要备份的类型 @@ -1682,11 +1688,13 @@ func (s *session) executeAllStatement(ctx context.Context) { } } } + + lastIsDMLTrans = true case *ast.UseStmt, *ast.SetStmt: // 环境命令 // 事务内部和非事务均需要执行 // log.Infof("1111: [%s] [%d] %s,RowsAffected: %d", s.DBName, s.fetchThreadID(), record.Sql, record.AffectedRows) - _, err := s.Exec(record.Sql, true) + _, err := s.ExecDDL(record.Sql, true) if err != nil { // log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) if myErr, ok := err.(*mysqlDriver.MySQLError); ok { @@ -1714,7 +1722,14 @@ func (s *session) executeAllStatement(ctx context.Context) { trans = nil } - s.executeRemoteCommand(record) + // 如果前端是DML语句,则在执行DDL前切换一次数据库 + // log.Infof("lastIsDMLTrans: %v", lastIsDMLTrans) + if lastIsDMLTrans { + s.SwitchDatabase(s.ddlDB) + lastIsDMLTrans = false + } + + s.executeRemoteCommand(record, true) // trans = append(trans, record) // s.executeTransaction(trans) @@ -1729,7 +1744,7 @@ func (s *session) executeAllStatement(ctx context.Context) { } } } else { - s.executeRemoteCommand(record) + s.executeRemoteCommand(record, false) } if s.hasErrorBefore() { @@ -1944,7 +1959,7 @@ func (s *session) executeTransaction(records []*Record) int { return 0 } -func (s *session) executeRemoteCommand(record *Record) int { +func (s *session) executeRemoteCommand(record *Record, isTran bool) int { s.myRecord = record record.Stage = StageExec @@ -1970,7 +1985,7 @@ func (s *session) executeRemoteCommand(record *Record) int { *ast.SetStmt, *ast.DropIndexStmt: - s.executeRemoteStatement(record) + s.executeRemoteStatement(record, isTran) default: log.Infof("无匹配类型: %T\n", node) @@ -2179,10 +2194,10 @@ func statisticsTableSQL() string { return buf.String() } -func (s *session) executeRemoteStatement(record *Record) { +func (s *session) executeRemoteStatement(record *Record, isTran bool) { log.Debug("executeRemoteStatement") - sql := record.Sql + sqlStmt := record.Sql start := time.Now() @@ -2203,7 +2218,13 @@ func (s *session) executeRemoteStatement(record *Record) { return } else { - res, err := s.Exec(sql, false) + var res sql.Result + var err error + if isTran { + res, err = s.ExecDDL(sqlStmt, false) + } else { + res, err = s.Exec(sqlStmt, false) + } record.ExecTime = fmt.Sprintf("%.3f", time.Since(start).Seconds()) record.ExecTimestamp = time.Now().Unix() @@ -2293,7 +2314,7 @@ func (s *session) executeRemoteStatementAndBackup(record *Record) { return } - s.executeRemoteStatement(record) + s.executeRemoteStatement(record, false) if !s.hasError() || record.ExecComplete { if s.opt.backup { @@ -2904,6 +2925,11 @@ func (s *session) parseOptions(sql string) { return } + if s.opt.tranBatch > 1 { + s.ddlDB, _ = gorm.Open("mysql", fmt.Sprintf("%s&autocommit=1", addr)) + s.ddlDB.LogMode(false) + } + // 禁用日志记录器,不显示任何日志 db.LogMode(false) diff --git a/session/session_inception_common_test.go b/session/session_inception_common_test.go index 2b1af45c..e8407803 100644 --- a/session/session_inception_common_test.go +++ b/session/session_inception_common_test.go @@ -507,10 +507,9 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri // 如果表改变了,或者超过500行了 if lastTable != currentTable || len(ids) >= 500 { - lastTable = currentTable if len(ids) > 0 { sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;" - sql = fmt.Sprintf(sql, currentTable) + sql = fmt.Sprintf(sql, lastTable) rows, err := s.db.Raw(sql, ids).Rows() c.Assert(err, IsNil) @@ -522,11 +521,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri } rows.Close() - c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v", sql)) + c.Assert(len(result1), Not(Equals), 0, Commentf("-----------: %v,%v", sql, ids)) result = append(result, result1...) ids = nil } + lastTable = currentTable } ids = append(ids, opid) @@ -551,7 +551,7 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri result = append(result, result1...) } - c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", rows)) + c.Assert(len(result), Equals, len(rollbackSqls), Commentf("%v", result)) // 如果是UPDATE多表操作,此时回滚的SQL可能是无序的 if len(result) > 1 && strings.HasPrefix(result[0], "UPDATE") { diff --git a/session/session_inception_tran_test.go b/session/session_inception_tran_test.go index e3c72e9c..60a195d2 100644 --- a/session/session_inception_tran_test.go +++ b/session/session_inception_tran_test.go @@ -444,3 +444,56 @@ func (s *testSessionIncTranSuite) TestDelete(c *C) { c.Assert(backup, Equals, "INSERT INTO `test_inc`.`t1`(`id`,`c1`) VALUES(1,'😁😄🙂👩');", Commentf("%v", res.Rows())) } + +func (s *testSessionIncTranSuite) TestCreateTable(c *C) { + saved := config.GetGlobalConfig().Inc + defer func() { + config.GetGlobalConfig().Inc = saved + }() + + var ( + res *testkit.Result + // row []interface{} + // backup string + ) + + res = s.mustRunBackupTran(c, `DROP TABLE IF EXISTS t1,t2; + + CREATE TABLE t1 (id int(11) NOT NULL, + c1 int(11) DEFAULT NULL, + c2 int(11) DEFAULT NULL, + PRIMARY KEY (id)); + + INSERT INTO t1 VALUES (1, 1, 1); + + CREATE TABLE t2 (id int(11) NOT NULL, + c1 int(11) DEFAULT NULL, + c2 int(11) DEFAULT NULL, + PRIMARY KEY (id))`) + s.assertRows(c, res.Rows()[2:], + "DROP TABLE `test_inc`.`t1`;", + "DELETE FROM `test_inc`.`t1` WHERE `id`=1;", + "DROP TABLE `test_inc`.`t2`;") + + res = s.mustRunBackupTran(c, `DROP TABLE IF EXISTS t1,t2; + create table t1(id int primary key,c1 int); + insert into t1 values(1,1),(2,2); + delete from t1 where id=1; + alter table t1 add column c2 int; + insert into t1 values(3,3,3); + delete from t1 where id>0; + create table t2(id int primary key,c1 int); + insert into t2 values(3,3);`) + s.assertRows(c, res.Rows()[2:], + "DROP TABLE `test_inc`.`t1`;", + "DELETE FROM `test_inc`.`t1` WHERE `id`=1;", + "DELETE FROM `test_inc`.`t1` WHERE `id`=2;", + "INSERT INTO `test_inc`.`t1`(`id`,`c1`) VALUES(1,1);", + "ALTER TABLE `test_inc`.`t1` DROP COLUMN `c2`;", + "DELETE FROM `test_inc`.`t1` WHERE `id`=3;", + "INSERT INTO `test_inc`.`t1`(`id`,`c1`,`c2`) VALUES(2,2,NULL);", + "INSERT INTO `test_inc`.`t1`(`id`,`c1`,`c2`) VALUES(3,3,3);", + "DROP TABLE `test_inc`.`t2`;", + "DELETE FROM `test_inc`.`t2` WHERE `id`=3;") + +} From c165a35d6e591723595de3b6c4c583169c223cbb Mon Sep 17 00:00:00 2001 From: hanchuanchuan Date: Fri, 3 Apr 2020 17:43:44 +0800 Subject: [PATCH 2/2] =?UTF-8?q?update:=20=E7=BB=86=E8=8A=82=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/conn.go | 88 ++++++++++++------------ session/session_inception.go | 12 ++-- session/session_inception_common_test.go | 18 ++--- 3 files changed, 59 insertions(+), 59 deletions(-) diff --git a/session/conn.go b/session/conn.go index bcd011ff..ed1ddb37 100644 --- a/session/conn.go +++ b/session/conn.go @@ -208,53 +208,53 @@ func (s *session) initConnection() (err error) { return } -// SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空) -func (s *session) SwitchDatabase(db *gorm.DB) error { - name := s.DBName - if name == "" { - name = s.opt.db - } - if name == "" { - return nil - } +// // SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空) +// func (s *session) SwitchDatabase(db *gorm.DB) error { +// name := s.DBName +// if name == "" { +// name = s.opt.db +// } +// if name == "" { +// return nil +// } - // log.Infof("SwitchDatabase: %v", name) - _, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name)) - if err != nil { - log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) - if myErr, ok := err.(*mysqlDriver.MySQLError); ok { - s.AppendErrorMessage(myErr.Message) - } else { - s.AppendErrorMessage(err.Error()) - } - } - return err -} +// // log.Infof("SwitchDatabase: %v", name) +// _, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name)) +// if err != nil { +// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) +// if myErr, ok := err.(*mysqlDriver.MySQLError); ok { +// s.AppendErrorMessage(myErr.Message) +// } else { +// s.AppendErrorMessage(err.Error()) +// } +// } +// return err +// } -// GetDatabase 获取当前数据库 -func (s *session) GetDatabase() string { - log.Debug("GetDatabase") +// // GetDatabase 获取当前数据库 +// func (s *session) GetDatabase() string { +// log.Debug("GetDatabase") - var value string - sql := "select database();" +// var value string +// sql := "select database();" - rows, err := s.Raw(sql) - if rows != nil { - defer rows.Close() - } +// rows, err := s.Raw(sql) +// if rows != nil { +// defer rows.Close() +// } - if err != nil { - log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) - if myErr, ok := err.(*mysqlDriver.MySQLError); ok { - s.AppendErrorMessage(myErr.Message) - } else { - s.AppendErrorMessage(err.Error()) - } - } else { - for rows.Next() { - rows.Scan(&value) - } - } +// if err != nil { +// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err) +// if myErr, ok := err.(*mysqlDriver.MySQLError); ok { +// s.AppendErrorMessage(myErr.Message) +// } else { +// s.AppendErrorMessage(err.Error()) +// } +// } else { +// for rows.Next() { +// rows.Scan(&value) +// } +// } - return value -} +// return value +// } diff --git a/session/session_inception.go b/session/session_inception.go index 82de5670..694deadb 100644 --- a/session/session_inception.go +++ b/session/session_inception.go @@ -1658,7 +1658,7 @@ func (s *session) executeAllStatement(ctx context.Context) { } // 用于事务. 判断是否为DML语句 - lastIsDMLTrans := false + // lastIsDMLTrans := false for i, record := range s.recordSets.All() { // 忽略不需要备份的类型 @@ -1689,7 +1689,7 @@ func (s *session) executeAllStatement(ctx context.Context) { } } - lastIsDMLTrans = true + // lastIsDMLTrans = true case *ast.UseStmt, *ast.SetStmt: // 环境命令 // 事务内部和非事务均需要执行 @@ -1724,10 +1724,10 @@ func (s *session) executeAllStatement(ctx context.Context) { // 如果前端是DML语句,则在执行DDL前切换一次数据库 // log.Infof("lastIsDMLTrans: %v", lastIsDMLTrans) - if lastIsDMLTrans { - s.SwitchDatabase(s.ddlDB) - lastIsDMLTrans = false - } + // if lastIsDMLTrans { + // s.SwitchDatabase(s.ddlDB) + // lastIsDMLTrans = false + // } s.executeRemoteCommand(record, true) diff --git a/session/session_inception_common_test.go b/session/session_inception_common_test.go index e8407803..60b91080 100644 --- a/session/session_inception_common_test.go +++ b/session/session_inception_common_test.go @@ -491,12 +491,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri if tableName == "" { sql := "select tablename from `%s`.`%s` where opid_time = ?" sql = fmt.Sprintf(sql, backupDBName, s.remoteBackupTable) - rows, err := s.db.Raw(sql, opid).Rows() + tableRows, err := s.db.Raw(sql, opid).Rows() c.Assert(err, IsNil) - for rows.Next() { - rows.Scan(&tableName) + for tableRows.Next() { + tableRows.Scan(&tableName) } - rows.Close() + tableRows.Close() } c.Assert(tableName, Not(Equals), "", Commentf("%v", row)) @@ -536,18 +536,18 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri if len(ids) > 0 { sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;" sql = fmt.Sprintf(sql, currentTable) - rows, err := s.db.Raw(sql, ids).Rows() + rollbackRows, err := s.db.Raw(sql, ids).Rows() c.Assert(err, IsNil) str := "" result1 := []string{} - for rows.Next() { - rows.Scan(&str) + for rollbackRows.Next() { + rollbackRows.Scan(&str) result1 = append(result1, s.trim(str)) } - rows.Close() + rollbackRows.Close() - c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", sql)) + c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", rows)) result = append(result, result1...) }