Skip to content

Commit

Permalink
update: 细节优化
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchuanchuan committed Apr 3, 2020
1 parent f4362e5 commit c165a35
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 59 deletions.
88 changes: 44 additions & 44 deletions session/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,53 +208,53 @@ func (s *session) initConnection() (err error) {
return
}

// SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空)
func (s *session) SwitchDatabase(db *gorm.DB) error {
name := s.DBName
if name == "" {
name = s.opt.db
}
if name == "" {
return nil
}
// // SwitchDatabase USE切换到当前数据库. (避免连接断开后当前数据库置空)
// func (s *session) SwitchDatabase(db *gorm.DB) error {
// name := s.DBName
// if name == "" {
// name = s.opt.db
// }
// if name == "" {
// return nil
// }

// log.Infof("SwitchDatabase: %v", name)
_, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name))
if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
s.AppendErrorMessage(myErr.Message)
} else {
s.AppendErrorMessage(err.Error())
}
}
return err
}
// // log.Infof("SwitchDatabase: %v", name)
// _, err := db.DB().Exec(fmt.Sprintf("USE `%s`", name))
// if err != nil {
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
// if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
// s.AppendErrorMessage(myErr.Message)
// } else {
// s.AppendErrorMessage(err.Error())
// }
// }
// return err
// }

// GetDatabase 获取当前数据库
func (s *session) GetDatabase() string {
log.Debug("GetDatabase")
// // GetDatabase 获取当前数据库
// func (s *session) GetDatabase() string {
// log.Debug("GetDatabase")

var value string
sql := "select database();"
// var value string
// sql := "select database();"

rows, err := s.Raw(sql)
if rows != nil {
defer rows.Close()
}
// rows, err := s.Raw(sql)
// if rows != nil {
// defer rows.Close()
// }

if err != nil {
log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
s.AppendErrorMessage(myErr.Message)
} else {
s.AppendErrorMessage(err.Error())
}
} else {
for rows.Next() {
rows.Scan(&value)
}
}
// if err != nil {
// log.Errorf("con:%d %v", s.sessionVars.ConnectionID, err)
// if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
// s.AppendErrorMessage(myErr.Message)
// } else {
// s.AppendErrorMessage(err.Error())
// }
// } else {
// for rows.Next() {
// rows.Scan(&value)
// }
// }

return value
}
// return value
// }
12 changes: 6 additions & 6 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,7 @@ func (s *session) executeAllStatement(ctx context.Context) {
}

// 用于事务. 判断是否为DML语句
lastIsDMLTrans := false
// lastIsDMLTrans := false
for i, record := range s.recordSets.All() {

// 忽略不需要备份的类型
Expand Down Expand Up @@ -1689,7 +1689,7 @@ func (s *session) executeAllStatement(ctx context.Context) {
}
}

lastIsDMLTrans = true
// lastIsDMLTrans = true
case *ast.UseStmt, *ast.SetStmt:
// 环境命令
// 事务内部和非事务均需要执行
Expand Down Expand Up @@ -1724,10 +1724,10 @@ func (s *session) executeAllStatement(ctx context.Context) {

// 如果前端是DML语句,则在执行DDL前切换一次数据库
// log.Infof("lastIsDMLTrans: %v", lastIsDMLTrans)
if lastIsDMLTrans {
s.SwitchDatabase(s.ddlDB)
lastIsDMLTrans = false
}
// if lastIsDMLTrans {
// s.SwitchDatabase(s.ddlDB)
// lastIsDMLTrans = false
// }

s.executeRemoteCommand(record, true)

Expand Down
18 changes: 9 additions & 9 deletions session/session_inception_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,12 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
if tableName == "" {
sql := "select tablename from `%s`.`%s` where opid_time = ?"
sql = fmt.Sprintf(sql, backupDBName, s.remoteBackupTable)
rows, err := s.db.Raw(sql, opid).Rows()
tableRows, err := s.db.Raw(sql, opid).Rows()
c.Assert(err, IsNil)
for rows.Next() {
rows.Scan(&tableName)
for tableRows.Next() {
tableRows.Scan(&tableName)
}
rows.Close()
tableRows.Close()
}
c.Assert(tableName, Not(Equals), "", Commentf("%v", row))

Expand Down Expand Up @@ -536,18 +536,18 @@ func (s *testCommon) assertRows(c *C, rows [][]interface{}, rollbackSqls ...stri
if len(ids) > 0 {
sql := "select rollback_statement from %s where opid_time in (?) order by opid_time,id;"
sql = fmt.Sprintf(sql, currentTable)
rows, err := s.db.Raw(sql, ids).Rows()
rollbackRows, err := s.db.Raw(sql, ids).Rows()
c.Assert(err, IsNil)

str := ""
result1 := []string{}
for rows.Next() {
rows.Scan(&str)
for rollbackRows.Next() {
rollbackRows.Scan(&str)
result1 = append(result1, s.trim(str))
}
rows.Close()
rollbackRows.Close()

c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", sql))
c.Assert(len(result1), Not(Equals), 0, Commentf("------2-----: %v", rows))
result = append(result, result1...)
}

Expand Down

0 comments on commit c165a35

Please sign in to comment.