Skip to content

Commit

Permalink
fix: config target info by flag
Browse files Browse the repository at this point in the history
  • Loading branch information
newborn22 committed Dec 10, 2024
1 parent 972df4f commit 6f5e5cb
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 22 deletions.
21 changes: 20 additions & 1 deletion go/internal/global/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ Licensed under the Apache v2(found in the LICENSE file in the root directory).

package global

import "time"
import (
"github.com/spf13/pflag"
"time"
"vitess.io/vitess/go/vt/servenv"
)

// Keyspace
const (
Expand Down Expand Up @@ -56,3 +60,18 @@ const (
const (
TopoServerConfigOverwriteShard = true
)

// *****************************************************************************************************************************

var (
MysqlServerPort = -1
)

func registerPluginFlags(fs *pflag.FlagSet) {
fs.IntVar(&MysqlServerPort, "mysql_server_port", MysqlServerPort, "If set, also listen for MySQL binary protocol connections on this port.")
}

func init() {
servenv.OnParseFor("vtgate", registerPluginFlags)
servenv.OnParseFor("vtcombo", registerPluginFlags)
}
31 changes: 31 additions & 0 deletions go/viperutil/vtgate_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,35 @@ func RegisterReloadHandlersForVtGate(v *ViperConfig) {
log.Errorf("fail to reload config %s=%s, err: %v", key, value, err)
}
})

// branch
v.ReloadHandler.AddReloadHandler("branch_default_name", func(key string, value string, fs *pflag.FlagSet) {
if err := fs.Set("branch_default_name", value); err != nil {
log.Errorf("fail to reload config %s=%s, err: %v", key, value, err)
}
})

v.ReloadHandler.AddReloadHandler("branch_default_target_host", func(key string, value string, fs *pflag.FlagSet) {
if err := fs.Set("branch_default_target_host", value); err != nil {
log.Errorf("fail to reload config %s=%s, err: %v", key, value, err)
}
})

v.ReloadHandler.AddReloadHandler("branch_default_target_port", func(key string, value string, fs *pflag.FlagSet) {
if err := fs.Set("branch_default_target_port", value); err != nil {
log.Errorf("fail to reload config %s=%s, err: %v", key, value, err)
}
})

v.ReloadHandler.AddReloadHandler("branch_default_target_user", func(key string, value string, fs *pflag.FlagSet) {
if err := fs.Set("branch_default_target_user", value); err != nil {
log.Errorf("fail to reload config %s=%s, err: %v", key, value, err)
}
})

v.ReloadHandler.AddReloadHandler("branch_default_target_password", func(key string, value string, fs *pflag.FlagSet) {
if err := fs.Set("branch_default_target_password", value); err != nil {
log.Errorf("fail to reload config %s=%s, err: %v", key, value, err)
}
})
}
55 changes: 42 additions & 13 deletions go/vt/vtgate/engine/branch_primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"context"
"fmt"
"github.com/go-sql-driver/mysql"
"github.com/spf13/pflag"
"strconv"
"strings"
"vitess.io/vitess/go/internal/global"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/schemadiff"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/branch"
)
Expand All @@ -26,23 +29,39 @@ const (
Show BranchCommandType = "show"
)

// todo enhancement: add flags to config
const (
DefaultBranchName = "my_branch"

var (
DefaultBranchName = "my_branch"
DefaultBranchTargetHost = "127.0.0.1"
DefaultBranchTargetPort = 15306
DefaultBranchTargetPort = -1
DefaultBranchTargetUser = "root"
DefaultBranchTargetPassword = "passwd"
)

func registerBranchFlags(fs *pflag.FlagSet) {
// todo add dynamic handler
fs.StringVar(&DefaultBranchName, "branch_default_name", DefaultBranchName, "default branch name")
fs.StringVar(&DefaultBranchTargetHost, "branch_default_target_host", DefaultBranchTargetHost, "default branch target host")
fs.IntVar(&DefaultBranchTargetPort, "branch_default_target_port", DefaultBranchTargetPort, "default branch target port")
fs.StringVar(&DefaultBranchTargetUser, "branch_default_target_user", DefaultBranchTargetUser, "default branch target user")
fs.StringVar(&DefaultBranchTargetPassword, "branch_default_target_password", DefaultBranchTargetPassword, "default branch target password")
}

func init() {
servenv.OnParseFor("vtgate", registerBranchFlags)
}

// Branch is an operator to deal with branch commands
type Branch struct {
// set when plan building
name string
commandType BranchCommandType
params branchParams

targetHost string
targetPort int
targetUser string
targetPassword string

noInputs
}

Expand Down Expand Up @@ -120,6 +139,16 @@ func BuildBranchPlan(branchCmd *sqlparser.BranchCommand) (*Branch, error) {
if err != nil {
return nil, fmt.Errorf("invalid branch command params: %w", err)
}

b.targetHost = DefaultBranchTargetHost
if DefaultBranchTargetPort == -1 {
b.targetPort = global.MysqlServerPort
} else {
b.targetPort = DefaultBranchTargetPort
}
b.targetUser = DefaultBranchTargetUser
b.targetPassword = DefaultBranchTargetPassword

return b, nil
}

Expand Down Expand Up @@ -467,7 +496,7 @@ func (b *Branch) branchCreate() (*sqltypes.Result, error) {
if err != nil {
return nil, err
}
targetHandler, err := createBranchTargetHandler(DefaultBranchTargetUser, DefaultBranchTargetPassword, DefaultBranchTargetHost, DefaultBranchTargetPort)
targetHandler, err := createBranchTargetHandler(b.targetUser, b.targetPassword, b.targetHost, b.targetPort)
if err != nil {
return nil, err
}
Expand All @@ -482,7 +511,7 @@ func (b *Branch) branchDiff() (*sqltypes.Result, error) {
if !ok {
return nil, fmt.Errorf("branch diff: invalid branch command params")
}
meta, bs, _, _, err := getBranchDataStruct(b.name)
meta, bs, _, _, err := getBranchDataStruct(b.name, b.targetUser, b.targetPassword, b.targetHost, b.targetPort)
if err != nil {
return nil, err
}
Expand All @@ -502,7 +531,7 @@ func (b *Branch) branchPrepareMergeBack() (*sqltypes.Result, error) {
return nil, fmt.Errorf("branch prepare merge back: invalid branch command params")
}

meta, bs, _, _, err := getBranchDataStruct(b.name)
meta, bs, _, _, err := getBranchDataStruct(b.name, b.targetUser, b.targetPassword, b.targetHost, b.targetPort)
if err != nil {
return nil, err
}
Expand All @@ -517,7 +546,7 @@ func (b *Branch) branchPrepareMergeBack() (*sqltypes.Result, error) {
}

func (b *Branch) branchMergeBack() (*sqltypes.Result, error) {
meta, bs, _, _, err := getBranchDataStruct(b.name)
meta, bs, _, _, err := getBranchDataStruct(b.name, b.targetUser, b.targetPassword, b.targetHost, b.targetPort)
if err != nil {
return nil, err
}
Expand All @@ -526,7 +555,7 @@ func (b *Branch) branchMergeBack() (*sqltypes.Result, error) {

func (b *Branch) branchCleanUp() (*sqltypes.Result, error) {
// get target handler
targetHandler, err := createBranchTargetHandler(DefaultBranchTargetUser, DefaultBranchTargetPassword, DefaultBranchTargetHost, DefaultBranchTargetPort)
targetHandler, err := createBranchTargetHandler(b.targetUser, b.targetPassword, b.targetHost, b.targetPort)
if err != nil {
return nil, err
}
Expand All @@ -540,7 +569,7 @@ func (b *Branch) branchShow() (*sqltypes.Result, error) {
return nil, fmt.Errorf("branch show: invalid branch command params")
}

meta, _, _, targetHandler, err := getBranchDataStruct(b.name)
meta, _, _, targetHandler, err := getBranchDataStruct(b.name, b.targetUser, b.targetPassword, b.targetHost, b.targetPort)
if err != nil {
return nil, err
}
Expand All @@ -557,9 +586,9 @@ func (b *Branch) branchShow() (*sqltypes.Result, error) {
}
}

func getBranchDataStruct(name string) (*branch.BranchMeta, *branch.BranchService, *branch.SourceMySQLService, *branch.TargetMySQLService, error) {
func getBranchDataStruct(name string, targetUser, targetPassword, targetHost string, targetPort int) (*branch.BranchMeta, *branch.BranchService, *branch.SourceMySQLService, *branch.TargetMySQLService, error) {
// get target handler
targetHandler, err := createBranchTargetHandler(DefaultBranchTargetUser, DefaultBranchTargetPassword, DefaultBranchTargetHost, DefaultBranchTargetPort)
targetHandler, err := createBranchTargetHandler(targetUser, targetPassword, targetHost, targetPort)
if err != nil {
return nil, nil, nil, nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/declarative_ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ var configInStr = HintsInStr{
}

// not all options are supported to configure
func registerPoolSizeControllerConfigTypeFlags(fs *pflag.FlagSet) {
func registerDeclarativeDDLDiffHintsFlags(fs *pflag.FlagSet) {
fs.StringVar(&configInStr.AutoIncrementStrategy, "declarative_ddl_hints_auto_increment_strategy", configInStr.AutoIncrementStrategy, "auto increment strategy")
fs.StringVar(&configInStr.RangeRotationStrategy, "declarative_ddl_hints_range_rotation_strategy", configInStr.RangeRotationStrategy, "range rotation strategy")
fs.StringVar(&configInStr.ConstraintNamesStrategy, "declarative_ddl_hints_constraint_names_strategy", configInStr.ConstraintNamesStrategy, "constraint names strategy")
Expand All @@ -74,7 +74,7 @@ func registerPoolSizeControllerConfigTypeFlags(fs *pflag.FlagSet) {
}

func init() {
servenv.OnParseFor("vtgate", registerPoolSizeControllerConfigTypeFlags)
servenv.OnParseFor("vtgate", registerDeclarativeDDLDiffHintsFlags)
}

// DeclarativeDDL is an operator to send schema diff DDL queries to the specific keyspace, tabletType and destination
Expand Down
9 changes: 4 additions & 5 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"sync/atomic"
"syscall"
"time"
"vitess.io/vitess/go/internal/global"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"

Expand All @@ -57,7 +58,6 @@ import (
)

var (
mysqlServerPort = -1
mysqlServerBindAddress string
mysqlServerSocketPath string
mysqlTCPVersion = "tcp"
Expand Down Expand Up @@ -85,7 +85,6 @@ var (
)

func registerPluginFlags(fs *pflag.FlagSet) {
fs.IntVar(&mysqlServerPort, "mysql_server_port", mysqlServerPort, "If set, also listen for MySQL binary protocol connections on this port.")
fs.StringVar(&mysqlServerBindAddress, "mysql_server_bind_address", mysqlServerBindAddress, "Binds on this address when listening to MySQL binary protocol. Useful to restrict listening to 'localhost' only for instance.")
fs.StringVar(&mysqlServerSocketPath, "mysql_server_socket_path", mysqlServerSocketPath, "This option specifies the Unix socket file to use when listening for local connections. By default it will be empty and it won't listen to a unix socket")
fs.StringVar(&mysqlTCPVersion, "mysql_tcp_version", mysqlTCPVersion, "Select tcp, tcp4, or tcp6 to control the socket type.")
Expand Down Expand Up @@ -464,7 +463,7 @@ func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mys
// It should be called only once in a process.
func initMySQLProtocol() {
// Flag is not set, just return.
if mysqlServerPort < 0 && mysqlServerSocketPath == "" {
if global.MysqlServerPort < 0 && mysqlServerSocketPath == "" {
return
}

Expand Down Expand Up @@ -494,10 +493,10 @@ func initMySQLProtocol() {
// Create a Listener.
var err error
vtgateHandle = newVtgateHandler(rpcVTGate)
if mysqlServerPort >= 0 {
if global.MysqlServerPort >= 0 {
mysqlListener, err = mysql.NewListener(
mysqlTCPVersion,
net.JoinHostPort(mysqlServerBindAddress, fmt.Sprintf("%v", mysqlServerPort)),
net.JoinHostPort(mysqlServerBindAddress, fmt.Sprintf("%v", global.MysqlServerPort)),
authServer,
vtgateHandle,
mysqlConnReadTimeout,
Expand Down
3 changes: 2 additions & 1 deletion go/vt/vtgate/vtgate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"fmt"
"strings"
"testing"
"vitess.io/vitess/go/internal/global"

"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -68,7 +69,7 @@ func init() {
transactionMode = "MULTI"
Init(context.Background(), hcVTGateTest, newSandboxForCells([]string{"aa"}), "aa", nil, querypb.ExecuteOptions_Gen4)

mysqlServerPort = 0
global.MysqlServerPort = 0
mysqlAuthServerImpl = "none"
initMySQLProtocol()
}
Expand Down

0 comments on commit 6f5e5cb

Please sign in to comment.