Skip to content

Commit

Permalink
fix: add mysql service interface
Browse files Browse the repository at this point in the history
  • Loading branch information
newborn22 committed Dec 11, 2024
1 parent 8a7fdb1 commit 4f7bca0
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 74 deletions.
2 changes: 1 addition & 1 deletion go/vt/vtgate/branch/common_mysql_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package branch
import "fmt"

type CommonMysqlService struct {
mysqlService *MysqlService
mysqlService MysqlService
}

// GetBranchSchema retrieves CREATE TABLE statements for all tables in databases filtered by `databasesInclude` and `databasesExclude`
Expand Down
71 changes: 4 additions & 67 deletions go/vt/vtgate/branch/mysql_service.go
Original file line number Diff line number Diff line change
@@ -1,76 +1,13 @@
package branch

import (
"context"
"database/sql"
"fmt"

"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
)

type MysqlService struct {
db *sql.DB
}

func NewMysqlService(db *sql.DB) (*MysqlService, error) {
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping MySQL: %w", err)
}
return &MysqlService{db: db}, nil
}

func NewMysqlServiceWithConfig(config *mysql.Config) (*MysqlService, error) {
config.MultiStatements = true
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, fmt.Errorf("failed to connect to MySQL: %w", err)
}

service, err := NewMysqlService(db)
if err != nil {
db.Close()
return nil, err
}

return service, nil
}

func (m *MysqlService) Close() error {
return m.db.Close()
}

func (m *MysqlService) Query(query string) (*sql.Rows, error) {
return m.db.Query(query)
}

func (m *MysqlService) Exec(database, query string) (sql.Result, error) {
ctx := context.Background()
if database != "" {
query = fmt.Sprintf("USE %s; %s", database, query)
}
conn, err := m.db.Conn(ctx)
if err != nil {
return nil, err
}
defer conn.Close()
return conn.ExecContext(ctx, query)
}

func (m *MysqlService) ExecuteInTxn(queries ...string) error {
tx, err := m.db.Begin()
if err != nil {
return err
}
// make sure to rollback if any query fails
defer tx.Rollback()

for _, query := range queries {
_, err := tx.Exec(query)
if err != nil {
return err
}
}

return tx.Commit()
type MysqlService interface {
Query(query string) (*sql.Rows, error)
Exec(database, query string) (sql.Result, error)
ExecuteInTxn(queries ...string) error
}
76 changes: 76 additions & 0 deletions go/vt/vtgate/branch/native_mysql_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package branch

import (
"context"
"database/sql"
"fmt"

"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
)

type NativeMysqlService struct {
db *sql.DB
}

func NewMysqlService(db *sql.DB) (*NativeMysqlService, error) {
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping MySQL: %w", err)
}
return &NativeMysqlService{db: db}, nil
}

func NewMysqlServiceWithConfig(config *mysql.Config) (*NativeMysqlService, error) {
config.MultiStatements = true
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, fmt.Errorf("failed to connect to MySQL: %w", err)
}

service, err := NewMysqlService(db)
if err != nil {
db.Close()
return nil, err
}

return service, nil
}

func (m *NativeMysqlService) Close() error {
return m.db.Close()
}

func (m *NativeMysqlService) Query(query string) (*sql.Rows, error) {
return m.db.Query(query)
}

func (m *NativeMysqlService) Exec(database, query string) (sql.Result, error) {
ctx := context.Background()
if database != "" {
query = fmt.Sprintf("USE %s; %s", database, query)
}
conn, err := m.db.Conn(ctx)
if err != nil {
return nil, err
}
defer conn.Close()
return conn.ExecContext(ctx, query)
}

func (m *NativeMysqlService) ExecuteInTxn(queries ...string) error {
tx, err := m.db.Begin()
if err != nil {
return err
}
// make sure to rollback if any query fails
defer tx.Rollback()

for _, query := range queries {
_, err := tx.Exec(query)
if err != nil {
return err
}
}

return tx.Commit()
}
File renamed without changes.
4 changes: 2 additions & 2 deletions go/vt/vtgate/branch/source_mysql_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import (

type SourceMySQLService struct {
*CommonMysqlService
mysqlService *MysqlService
mysqlService MysqlService
}

func NewSourceMySQLService(mysqlService *MysqlService) *SourceMySQLService {
func NewSourceMySQLService(mysqlService MysqlService) *SourceMySQLService {
return &SourceMySQLService{
CommonMysqlService: &CommonMysqlService{
mysqlService: mysqlService,
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/branch/target_mysql_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (

type TargetMySQLService struct {
*CommonMysqlService
mysqlService *MysqlService
mysqlService MysqlService
}

func NewTargetMySQLService(mysqlService *MysqlService) *TargetMySQLService {
func NewTargetMySQLService(mysqlService MysqlService) *TargetMySQLService {
return &TargetMySQLService{
CommonMysqlService: &CommonMysqlService{
mysqlService: mysqlService,
Expand Down Expand Up @@ -82,7 +82,7 @@ func (t *TargetMySQLService) ApplySnapshot(name string) error {
return nil
}

func (t *TargetMySQLService) GetMysqlService() *MysqlService {
func (t *TargetMySQLService) GetMysqlService() MysqlService {
return t.mysqlService
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/branch/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
)

func NewMockMysqlService(t *testing.T) (*MysqlService, sqlmock.Sqlmock) {
func NewMockMysqlService(t *testing.T) (*NativeMysqlService, sqlmock.Sqlmock) {
// use QueryMatcherEqual to match exact query
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))

Expand Down

0 comments on commit 4f7bca0

Please sign in to comment.