Skip to content

Commit

Permalink
Merge pull request mosn#1446 from nejisama/tls/auto_fallback
Browse files Browse the repository at this point in the history
Tls/auto fallback.
  • Loading branch information
wangfakang authored Nov 1, 2020
2 parents 29c5aef + 14d94b6 commit b2025cb
Show file tree
Hide file tree
Showing 17 changed files with 684 additions and 323 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/reviewdog.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: golangci-lint
uses: reviewdog/action-golangci-lint@v1
with:
golangci_lint_flags: "--enable-all --timeout=10m --exclude-use-default=false --tests=false --disable=gochecknoinits,gochecknoglobals,exhaustive"
golangci_lint_flags: "--enable-all --timeout=10m --exclude-use-default=false --tests=false --disable=gochecknoinits,gochecknoglobals,exhaustive,nakedret"
workdir: pkg

test:
Expand Down
443 changes: 222 additions & 221 deletions pkg/mock/upstream.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/mtls/confighook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (hook *testConfigHooks) verifyPeerCertificate(roots *x509.CertPool, certs [
return err
}
if leaf.Subject.CommonName != hook.PassCommonName {
return errors.New("common name miss match")
return errors.New("tls: common name miss match")
}
return nil
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/mtls/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ type TLSConn struct {

func (c *TLSConn) Read(b []byte) (int, error) {
n, err := c.Conn.Read(b)
if err != nil && strings.Contains(err.Error(), "tls") {
log.DefaultLogger.Alertf(types.ErrorKeyTLSRead, "[mtls] tls connection read error: %v, local address: %v, remote address: %v", err, c.Conn.LocalAddr(), c.Conn.RemoteAddr())
if err != nil {
if strings.Contains(err.Error(), "tls") {
log.DefaultLogger.Alertf(types.ErrorKeyTLSRead, "[mtls] tls connection read error: %v, local address: %v, remote address: %v", err, c.Conn.LocalAddr(), c.Conn.RemoteAddr())
}
if !c.Conn.GetConnectionState().HandshakeComplete {
// wraps as a new error which makes the read error and do no retry
return n, errors.New("tls: handshake is not completed: " + err.Error())
}
}
return n, err
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/mtls/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type MockListener struct {
Mng types.TLSContextManager
}

func MockClient(t *testing.T, addr string, cltMng types.TLSContextManager) (*http.Response, error) {
func MockClient(t *testing.T, addr string, cltMng types.TLSClientContextManager) (*http.Response, error) {
c, err := net.Dial("tcp", addr)
if err != nil {
return nil, fmt.Errorf("request server error %v", err)
Expand All @@ -80,9 +80,11 @@ func MockClient(t *testing.T, addr string, cltMng types.TLSContextManager) (*htt
conn = c
if cltMng != nil {
req, _ = http.NewRequest("GET", "https://"+addr, nil)
conn, _ = cltMng.Conn(c)
tlsConn, _ := conn.(*TLSConn)
if err := tlsConn.Handshake(); err != nil {
conn, err = cltMng.Conn(c)
if err != nil && cltMng.Fallback() {
conn, err = net.Dial("tcp", addr)
}
if err != nil {
return nil, fmt.Errorf("request tls handshake error %v", err)
}
} else {
Expand Down
140 changes: 140 additions & 0 deletions pkg/mtls/static_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package mtls

import (
"io/ioutil"
"net"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -496,5 +498,143 @@ func TestFallback(t *testing.T) {
t.Fatal("create tls client context without certificate success, expected failed")
}
})
}

func TestClientFallBack(t *testing.T) {
// A server not support tls
lc := &v2.Listener{}
ctxMng, err := NewTLSServerContextManager(lc)
if err != nil {
t.Fatalf("tls context manager error: %v", err)
}
server := MockServer{
Mng: ctxMng,
t: t,
}
server.GoListenAndServe(t)
defer server.Close()
time.Sleep(time.Second) //wait server start
// A Client with fallback
fallbackConfig := &v2.TLSConfig{
Status: true,
InsecureSkip: true,
Fallback: true,
}
fallbackMng, err := NewTLSClientContextManager(fallbackConfig)
if err != nil {
t.Fatalf("tls context manager error: %v", err)
}
resp, err := MockClient(t, server.Addr, fallbackMng)
if !pass(resp, err) {
t.Fatalf("fallback request failed")
}
}

func TestHandshaketTimeout(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error: %v", err)
}
addr := ln.Addr().String()
tlsConfig := &v2.TLSConfig{
Status: true,
InsecureSkip: true,
}
cltMng, err := NewTLSClientContextManager(tlsConfig)
if err != nil {
t.Fatalf("tls context manager error: %v", err)
}
c, err := net.DialTimeout("tcp", addr, time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
handshakeTimeout = time.Second
conn, err := cltMng.Conn(c)
if err == nil {
conn.Close()
t.Fatalf("expected connect failed, but success")
}
t.Logf("tls connect failed: %v", err)
}

func TestReadError(t *testing.T) {
ci := &certInfo{"Cert1", "RSA", "www.example.com"}
ctx, _ := ci.CreateCertConfig()
filterChains := []v2.FilterChain{
{
TLSContexts: []v2.TLSConfig{
*ctx,
},
},
}
lc := &v2.Listener{
ListenerConfig: v2.ListenerConfig{
FilterChains: filterChains,
},
}
ctxMng, err := NewTLSServerContextManager(lc)
if err != nil {
t.Fatalf("tls context manager error: %v", err)
}
addrch := make(chan string, 1)
resch := make(chan string, 1)
// mock a read loop
go func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen failed: %v", err)
}
addrch <- ln.Addr().String()
for {
conn, err := ln.Accept()
if err != nil {
return
}
tlsconn, err := ctxMng.Conn(conn)
if err != nil {
conn.Close()
resch <- "conn error: " + err.Error()
return
}
b := make([]byte, 100)
loop := 0
READLOOP:
for {
tlsconn.SetReadDeadline(time.Now().Add(2 * time.Second))
if _, err := tlsconn.Read(b); err != nil {
if te, ok := err.(net.Error); ok && te.Timeout() {
loop++
if loop < 100 {
continue READLOOP
} else {
resch <- "dead loop"
return
}
}
resch <- "conn read error: " + err.Error()
return
} else {
resch <- string(b)
return
}
}
}
}()
addr := <-addrch
conn, err := net.DialTimeout("tcp", addr, 2*time.Second)
if err != nil {
t.Fatalf("dial failed %v", err)
}
defer conn.Close()
// do not do handshake wait result
select {
case result := <-resch:
// if no fix in conn.go: c.Conn.GetConnectionState().HandshakeComplete check
// this should returns dead loop
if !strings.Contains(result, "tls: handshake is not completed") {
t.Fatalf("got result: %s", result)
}
case <-time.After(10 * time.Second):
t.Fatalf("wait result timeout")
}
}
32 changes: 20 additions & 12 deletions pkg/mtls/tls_context_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package mtls
import (
"net"
"reflect"
"time"

"mosn.io/mosn/pkg/config/v2"
"mosn.io/mosn/pkg/log"
Expand Down Expand Up @@ -132,42 +133,45 @@ func (mng *serverContextManager) Enabled() bool {
return false
}

// The serverContextManager's HashValue is not used in mosn.
// maybe we will use it later.
func (mng *serverContextManager) HashValue() *types.HashValue {
if len(mng.providers) == 0 {
return nil
}
p := mng.providers[0]
return p.GetTLSConfigContext(false).HashValue()
}

type clientContextManager struct {
// client support only one certificate
provider types.TLSProvider
// fallback
fallback bool
}

// NewTLSClientContextManager returns a types.TLSContextManager used in TLS Client
func NewTLSClientContextManager(cfg *v2.TLSConfig) (types.TLSContextManager, error) {
func NewTLSClientContextManager(cfg *v2.TLSConfig) (types.TLSClientContextManager, error) {
provider, err := NewProvider(cfg)
if err != nil {
return nil, err
}
mng := &clientContextManager{
provider: provider,
fallback: cfg.Fallback,
}
return mng, nil
}

var handshakeTimeout = types.DefaultConnReadTimeout

func (mng *clientContextManager) Conn(c net.Conn) (net.Conn, error) {
if _, ok := c.(*net.TCPConn); !ok {
return c, nil
}
if !mng.Enabled() {
return c, nil
}
// make tls connection and try handshake
tlsconn := tls.Client(c, mng.provider.GetTLSConfigContext(true).Config())
tlsconn.SetReadDeadline(time.Now().Add(handshakeTimeout))
if err := tlsconn.Handshake(); err != nil {
c.Close() // close the failed connection
return nil, err
}

return &TLSConn{
tls.Client(c, mng.provider.GetTLSConfigContext(true).Config()),
tlsconn,
}, nil
}

Expand All @@ -183,3 +187,7 @@ func (mng *clientContextManager) HashValue() *types.HashValue {
return mng.provider.GetTLSConfigContext(true).HashValue()

}

func (mng *clientContextManager) Fallback() bool {
return mng.fallback
}
Loading

0 comments on commit b2025cb

Please sign in to comment.