Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat auth supported for clients #111

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ci/fuzz/redisparser/redisparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ import (
)

var (
pc proto.ProxyConn
pc proto.ProxyConn
msgs []*proto.Message
nc *libnet.Conn
nc *libnet.Conn
)

func Fuzz(data []byte) int {
conn := _createConn(data)
nc := libnet.NewConn(conn, time.Second, time.Second)
pc = redis.NewProxyConn(nc)
pc = redis.NewProxyConn(nc, "")
msgs = proto.GetMsgs(4)
nmsgs, err := pc.Decode(msgs)
if err == bufio.ErrBufferFull {
Expand Down Expand Up @@ -56,7 +56,7 @@ func (m mockAddr) String() string {

type mockConn struct {
addr mockAddr
buf *bytes.Buffer
buf *bytes.Buffer
err error
closed int32
}
Expand Down Expand Up @@ -107,8 +107,8 @@ func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
// _createConn is useful tools for handler test
func _createConn(data []byte) net.Conn {
mconn := &mockConn{
addr: "127.0.0.1:12345",
buf: bytes.NewBuffer(data),
addr: "127.0.0.1:12345",
buf: bytes.NewBuffer(data),
}
return mconn
}
37 changes: 37 additions & 0 deletions cmd/proxy/proxy-redis-example.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
[[clusters]]
# This be used to specify the name of cache cluster.
name = "test-redis"
# The name of the hash function. Possible values are: sha1.
hash_method = "fnv1a_64"
# The key distribution mode. Possible values are: ketama.
hash_distribution = "ketama"
# A two character string that specifies the part of the key used for hashing. Eg "{}".
hash_tag = ""
# cache type: memcache | memcache_binary | redis | redis_cluster
cache_type = "redis"
# proxy listen proto: tcp | unix
listen_proto = "tcp"
# proxy listen addr: tcp addr | unix sock path
listen_addr = "0.0.0.0:26379"
# Authenticate to the Redis server on connect.
redis_auth = ""
# The dial timeout value in msec that we wait for to establish a connection to the server. By default, we wait indefinitely.
dial_timeout = 1000
# The read timeout value in msec that we wait for to receive a response from a server. By default, we wait indefinitely.
read_timeout = 1000
# The write timeout value in msec that we wait for to write a response to a server. By default, we wait indefinitely.
write_timeout = 1000
# The number of connections that can be opened to each server. By default, we open at most 1 server connection.
node_connections = 2
# The number of consecutive failures on a server that would lead to it being temporarily ejected when auto_eject is set to true. Defaults to 3.
ping_fail_limit = 3
# A boolean value that controls if server should be ejected temporarily when it fails consecutively ping_fail_limit times.
ping_auto_eject = false

slowlog_slower_than = 10
# A list of server address, port and weight (name:port:weight or ip:port:weight) for this server pool. Also you can use alias name like: ip:port:weight alias.
servers = [
"127.0.0.1:6379:1 redis1",
]
# Clients need to AUTH <PASSWORD> before processing any other commands.
password = "111"
1 change: 1 addition & 0 deletions proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type ClusterConfig struct {
PingAutoEject bool `toml:"ping_auto_eject"`
SlowlogSlowerThan int `toml:"slowlog_slower_than"`
Servers []string `toml:"servers"`
Password string `toml:"password"`
}

// ValidateStandalone validate redis/memcache address is valid or not
Expand Down
4 changes: 2 additions & 2 deletions proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ func NewHandler(p *Proxy, cc *ClusterConfig, conn net.Conn, forwarder proto.Forw
case types.CacheTypeMemcacheBinary:
h.pc = mcbin.NewProxyConn(h.conn)
case types.CacheTypeRedis:
h.pc = redis.NewProxyConn(h.conn)
h.pc = redis.NewProxyConn(h.conn, cc.Password)
case types.CacheTypeRedisCluster:
h.pc = rclstr.NewProxyConn(h.conn, forwarder)
h.pc = rclstr.NewProxyConn(h.conn, forwarder, cc.Password)
default:
panic(types.ErrNoSupportCacheType)
}
Expand Down
5 changes: 3 additions & 2 deletions proxy/proto/redis/cluster/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cluster
import (
"bytes"
errs "errors"

"overlord/pkg/conv"
libnet "overlord/pkg/net"
"overlord/proxy/proto"
Expand All @@ -29,14 +30,14 @@ type proxyConn struct {
}

// NewProxyConn creates new redis cluster Encoder and Decoder.
func NewProxyConn(conn *libnet.Conn, fer proto.Forwarder) proto.ProxyConn {
func NewProxyConn(conn *libnet.Conn, fer proto.Forwarder, password string) proto.ProxyConn {
var c *cluster
if fer != nil {
c = fer.(*cluster)
}
r := &proxyConn{
c: c,
pc: redis.NewProxyConn(conn),
pc: redis.NewProxyConn(conn, password),
}
return r
}
Expand Down
98 changes: 65 additions & 33 deletions proxy/proto/redis/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ const (
)

var (
nullBytes = []byte("-1\r\n")
okBytes = []byte("OK\r\n")
pongDataBytes = []byte("PONG")
justOkBytes = []byte("OK")
notSupportDataBytes = []byte("Error: command not support")
nullBytes = []byte("-1\r\n")
okBytes = []byte("OK\r\n")
pongDataBytes = []byte("PONG")
justOkBytes = []byte("OK")
invalidPasswordBytes = []byte("ERR invalid password")
noAuthBytes = []byte("NOAUTH Authentication required.")
noPasswordSetBytes = []byte("ERR Client sent AUTH, but no password is set.")
notSupportDataBytes = []byte("Error: command not support")
)

// ProxyConn is export for redis cluster.
Expand All @@ -34,21 +37,24 @@ func (pc *ProxyConn) Bw() *bufio.Writer {
}

type proxyConn struct {
br *bufio.Reader
bw *bufio.Writer
completed bool

resp *resp
br *bufio.Reader
bw *bufio.Writer
completed bool
resp *resp
authorized bool
password string
}

// NewProxyConn creates new redis Encoder and Decoder.
func NewProxyConn(conn *libnet.Conn) proto.ProxyConn {
func NewProxyConn(conn *libnet.Conn, password string) proto.ProxyConn {
r := &proxyConn{
br: bufio.NewReader(conn, bufio.Get(proxyReadBufSize)),
bw: bufio.NewWriter(conn),
completed: true,
resp: &resp{},
password: password,
}
r.authorized = password == ""
return r
}

Expand Down Expand Up @@ -183,31 +189,57 @@ func (pc *proxyConn) Encode(m *proto.Message) (err error) {
if !ok {
return ErrBadAssert
}
switch req.mType {
case mergeTypeOK:
err = pc.mergeOK(m)
case mergeTypeJoin:
err = pc.mergeJoin(m)
case mergeTypeCount:
err = pc.mergeCount(m)
default:
if !req.IsSupport() {
req.reply.respType = respError
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, notSupportDataBytes...)
} else if req.IsCtl() {
reqData := req.resp.array[0].data
if bytes.Equal(reqData, cmdPingBytes) {
req.reply.respType = respString
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, pongDataBytes...)
} else if bytes.Equal(reqData, cmdQuitBytes) {
req.reply.respType = respString
// general supported cmd need authorized
if !pc.authorized && !req.IsAuth() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not forward any command when pc.authorized is false. And so that authorized should be atomic flag or guard by mutex (be used by encoder/decoder together).

req.reply.respType = respError
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, noAuthBytes...)
err = req.reply.encode(pc.bw)
} else {

switch req.mType {
case mergeTypeOK:
err = pc.mergeOK(m)
case mergeTypeJoin:
err = pc.mergeJoin(m)
case mergeTypeCount:
err = pc.mergeCount(m)
default:
if !req.IsSupport() {
req.reply.respType = respError
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, justOkBytes...)
req.reply.data = append(req.reply.data, notSupportDataBytes...)
} else if req.IsCtl() {
reqData := req.resp.array[0].data
if bytes.Equal(reqData, cmdPingBytes) {
req.reply.respType = respString
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, pongDataBytes...)
} else if bytes.Equal(reqData, cmdQuitBytes) {
req.reply.respType = respString
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, justOkBytes...)
} else if bytes.Equal(reqData, cmdAuthBytes) {
if bytes.Equal(req.Key(), []byte(pc.password)) {
pc.authorized = true
req.reply.respType = respString
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, justOkBytes...)
} else if pc.password == "" {
req.reply.respType = respError
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, noPasswordSetBytes...)
} else {
pc.authorized = false
req.reply.respType = respError
req.reply.data = req.reply.data[:0]
req.reply.data = append(req.reply.data, invalidPasswordBytes...)
}

}
}
err = req.reply.encode(pc.bw)
}
err = req.reply.encode(pc.bw)
}
if err != nil {
err = errors.WithStack(err)
Expand Down
12 changes: 6 additions & 6 deletions proxy/proto/redis/proxy_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func _decodeMessage(t *testing.T, data string) []*proto.Message {
conn := libnet.NewConn(mockconn.CreateConn([]byte(data), 1), time.Second, time.Second)
pc := NewProxyConn(conn)
pc := NewProxyConn(conn, "")
msgs := proto.GetMsgs(16)
nmsgs, err := pc.Decode(msgs)
assert.NoError(t, err)
Expand Down Expand Up @@ -65,7 +65,7 @@ func TestDecodeBasicOk(t *testing.T) {
func TestDecodeComplexOk(t *testing.T) {
data := "*3\r\n$4\r\nMGET\r\n$4\r\nbaka\r\n$4\r\nkaba\r\n*5\r\n$4\r\nMSET\r\n$1\r\na\r\n$1\r\nb\r\n$3\r\neee\r\n$5\r\n12345\r\n*3\r\n$4\r\nMGET\r\n$4\r\nenen\r\n$4\r\nnime\r\n*2\r\n$3\r\nGET\r\n$5\r\nabcde\r\n*3\r\n$3\r\nDEL\r\n$1\r\na\r\n$1\r\nb\r\n"
conn := libnet.NewConn(mockconn.CreateConn([]byte(data), 1), time.Second, time.Second)
pc := NewProxyConn(conn)
pc := NewProxyConn(conn, "")
// test reuse command
msgs := proto.GetMsgs(16)
msgs[1].WithRequest(getReq())
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestEncodeNotSupportCtl(t *testing.T) {
}
msg.WithRequest(req)
conn := libnet.NewConn(mockconn.CreateConn(nil, 1), time.Second, time.Second)
pc := NewProxyConn(conn)
pc := NewProxyConn(conn, "")
err := pc.Encode(msg)
assert.NoError(t, err)
assert.Equal(t, req.reply.data, notSupportDataBytes)
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestEncodeMergeOk(t *testing.T) {
msg.Batch()
}
conn, buf := mockconn.CreateDownStreamConn()
pc := NewProxyConn(libnet.NewConn(conn, time.Second, time.Second))
pc := NewProxyConn(libnet.NewConn(conn, time.Second, time.Second), "")
err := pc.Encode(msg)
if !assert.NoError(t, err) {
return
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestEncodeWithError(t *testing.T) {
msg.Done()

conn, buf := mockconn.CreateDownStreamConn()
pc := NewProxyConn(libnet.NewConn(conn, time.Second, time.Second))
pc := NewProxyConn(libnet.NewConn(conn, time.Second, time.Second), "")
err := pc.Encode(msg)
assert.Error(t, err)
assert.Equal(t, mockErr, err)
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestEncodeWithPing(t *testing.T) {
msg.WithRequest(req)

conn, buf := mockconn.CreateDownStreamConn()
pc := NewProxyConn(libnet.NewConn(conn, time.Second, time.Second))
pc := NewProxyConn(libnet.NewConn(conn, time.Second, time.Second), "")
err := pc.Encode(msg)
assert.NoError(t, err)
err = pc.Flush()
Expand Down
14 changes: 11 additions & 3 deletions proxy/proto/redis/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ var (
cmdEvalBytes = []byte("4\r\nEVAL")
cmdQuitBytes = []byte("4\r\nQUIT")
cmdPingBytes = []byte("4\r\nPING")
cmdAuthBytes = []byte("4\r\nAUTH")
cmdMSetBytes = []byte("4\r\nMSET")
cmdMGetBytes = []byte("4\r\nMGET")
cmdGetBytes = []byte("3\r\nGET")
Expand Down Expand Up @@ -168,6 +169,13 @@ func (r *Request) IsSupport() bool {
return ok
}

func (r *Request) IsAuth() bool {
if r.IsCtl() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return r.IsCtrl() && bytes.Equal(r.resp.array[0].data, cmdAuthBytes)

return bytes.Equal(r.resp.array[0].data, cmdAuthBytes)
}
return false
}

// IsCtl is control command.
//
// NOTE: use string([]byte) as a map key, it is very specific!!!
Expand Down Expand Up @@ -307,13 +315,13 @@ var (
"5\r\nPFADD",
"7\r\nPFMERGE",
"4\r\nEVAL",
"11\r\nSUNIONSTORE",
"11\r\nZUNIONSTORE",
}
notSupportCmds = []string{
"6\r\nMSETNX",
"10\r\nSDIFFSTORE",
"11\r\nSINTERSTORE",
"11\r\nSUNIONSTORE",
"11\r\nZUNIONSTORE",
"5\r\nBLPOP",
"5\r\nBRPOP",
"10\r\nBRPOPLPUSH",
Expand All @@ -328,7 +336,6 @@ var (
"4\r\nWAIT",
"5\r\nBITOP",
"7\r\nEVALSHA",
"4\r\nAUTH",
"4\r\nECHO",
"4\r\nINFO",
"5\r\nPROXY",
Expand All @@ -341,5 +348,6 @@ var (
controlCmds = []string{
"4\r\nQUIT",
"4\r\nPING",
"4\r\nAUTH",
}
)
4 changes: 2 additions & 2 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ func (p *Proxy) accept(cc *ClusterConfig, l net.Listener, forwarder proto.Forwar
case types.CacheTypeMemcacheBinary:
encoder = mcbin.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second))
case types.CacheTypeRedis:
encoder = redis.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second))
encoder = redis.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second), cc.Password)
case types.CacheTypeRedisCluster:
encoder = rclstr.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second), nil)
encoder = rclstr.NewProxyConn(libnet.NewConn(conn, time.Second, time.Second), nil, cc.Password)
}
if encoder != nil {
_ = encoder.Encode(proto.ErrMessage(ErrProxyMoreMaxConns))
Expand Down