Skip to content

Commit

Permalink
auth: reconnect backend
Browse files Browse the repository at this point in the history
Signed-off-by: disksing <[email protected]>
  • Loading branch information
disksing committed Oct 30, 2023
1 parent 28d53d3 commit b353ff0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
21 changes: 18 additions & 3 deletions pkg/proxy/backend/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/go-mysql-org/go-mysql/mysql"
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tiproxy/lib/util/errors"
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
Expand Down Expand Up @@ -159,6 +160,8 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
auth.attrs = clientResp.Attrs
auth.zstdLevel = clientResp.ZstdLevel

RECONNECT:

// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
backendIO, err := getBackendIO(cctx, auth, clientResp, 15*time.Second)
if err != nil {
Expand Down Expand Up @@ -214,7 +217,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
pktIdx := 0
loop:
for {
serverPkt, err := forwardMsg(backendIO, clientIO)
serverPkt, err := backendIO.ReadPacket()
if err != nil {
// tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
// tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
Expand All @@ -223,6 +226,20 @@ loop:
}
return err
}
if serverPkt[0] == pnet.ErrHeader.Byte() {
err = pnet.ParseErrorPacket(serverPkt)
if handshakeHandler.HandleHandshakeErr(cctx, err.(*gomysql.MyError)) {
logger.Warn("handle handshake error, start reconnect", zap.Error(err))
backendIO.Close()
goto RECONNECT
}
return err
}
err = clientIO.WritePacket(serverPkt, true)
if err != nil {
return err
}

pktIdx++
switch serverPkt[0] {
case pnet.OKHeader.Byte():
Expand All @@ -233,8 +250,6 @@ loop:
return err
}
return nil
case pnet.ErrHeader.Byte():
return pnet.ParseErrorPacket(serverPkt)
default: // mysql.AuthSwitchRequest, ShaCommand
if serverPkt[0] == pnet.AuthSwitchHeader.Byte() {
pluginName = string(serverPkt[1 : bytes.IndexByte(serverPkt[1:], 0)+1])
Expand Down
14 changes: 14 additions & 0 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package backend

import (
gomysql "github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/tiproxy/lib/util/errors"
"github.com/pingcap/tiproxy/pkg/manager/namespace"
"github.com/pingcap/tiproxy/pkg/manager/router"
Expand Down Expand Up @@ -70,6 +71,7 @@ type ConnContext interface {

type HandshakeHandler interface {
HandleHandshakeResp(ctx ConnContext, resp *pnet.HandshakeResp) error
HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool // return true means retry connect
GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error)
OnHandshake(ctx ConnContext, to string, err error)
OnConnClose(ctx ConnContext) error
Expand All @@ -94,6 +96,10 @@ func (handler *DefaultHandshakeHandler) HandleHandshakeResp(ConnContext, *pnet.H
return nil
}

func (handler *DefaultHandshakeHandler) HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool {
return false
}

func (handler *DefaultHandshakeHandler) GetRouter(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) {
ns, ok := handler.nsManager.GetNamespaceByUser(resp.User)
if !ok {
Expand Down Expand Up @@ -142,6 +148,7 @@ type CustomHandshakeHandler struct {
onTraffic func(ConnContext)
onConnClose func(ConnContext) error
handleHandshakeResp func(ctx ConnContext, resp *pnet.HandshakeResp) error
handleHandshakeErr func(ctx ConnContext, err *gomysql.MyError) bool
getCapability func() pnet.Capability
getServerVersion func() string
}
Expand Down Expand Up @@ -179,6 +186,13 @@ func (h *CustomHandshakeHandler) HandleHandshakeResp(ctx ConnContext, resp *pnet
return nil
}

func (h *CustomHandshakeHandler) HandleHandshakeErr(ctx ConnContext, err *gomysql.MyError) bool {
if h.handleHandshakeErr != nil {
return h.handleHandshakeErr(ctx, err)
}
return false
}

func (h *CustomHandshakeHandler) GetCapability() pnet.Capability {
if h.getCapability != nil {
return h.getCapability()
Expand Down

0 comments on commit b353ff0

Please sign in to comment.