Skip to content

Commit

Permalink
Wrap proxyproto in package
Browse files Browse the repository at this point in the history
  • Loading branch information
tsipinakis committed May 13, 2022
1 parent 4b167ff commit ddddfd4
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 16 deletions.
34 changes: 34 additions & 0 deletions internal/proxyproto/proxyproto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package proxyproto

import (
"net"

"github.com/pires/go-proxyproto"
)

// WrapProxy is a function that wraps a net.Conn around the PROXY tcp protocol. It is used for correctly reporting the originator IP address when a service is running behind a load balancer
// In case proxy use is allowed the wrapped network connection is returned along with the IP address of the proxy that it is used. The wrapped network connection will return the IP address
// of the client when RemoteAddr() is called
//
// conn is the network connection to wrap
// proxyList is a list of addresses that are allowed to send proxy information
//
func WrapProxy(conn net.Conn, proxyList []string) (net.Conn, *net.TCPAddr, error) {
if len(proxyList) == 0 {
return conn, nil, nil
}
policyFunc := proxyproto.MustStrictWhiteListPolicy(proxyList)
policy, err := policyFunc(conn.RemoteAddr())
if err != nil {
return nil, nil, err
}
if policy == proxyproto.REJECT || policy == proxyproto.IGNORE {
// If it's not an approved proxy we should fail loudly, not silently
return conn, nil, nil
}
tcpAddr := conn.RemoteAddr().(*net.TCPAddr)
return proxyproto.NewConn(
conn,
proxyproto.WithPolicy(policy),
), tcpAddr, nil
}
148 changes: 148 additions & 0 deletions internal/proxyproto/proxyproto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package proxyproto_test

import (
"fmt"
"io"
"net"
"testing"
"time"

"github.com/containerssh/libcontainerssh/internal/proxyproto"
goproxyproto "github.com/pires/go-proxyproto"
)

type fakeConn struct {
remoteAddr string
localAddr string
pipeReader io.ReadCloser
pipeWriter io.WriteCloser
}

func NewFakeConn(clientAddr string, serverAddr string) (fakeConn, fakeConn) {
clientPipeReader, clientPipeWriter := io.Pipe()
serverPipeReader, serverPipeWriter := io.Pipe()
return fakeConn{
remoteAddr: clientAddr,
localAddr: serverAddr,
pipeReader: serverPipeReader,
pipeWriter: clientPipeWriter,
}, fakeConn{
remoteAddr: serverAddr,
localAddr: clientAddr,
pipeReader: clientPipeReader,
pipeWriter: serverPipeWriter,
}
}

func (f fakeConn) Read(b []byte) (n int, err error) {
return f.pipeReader.Read(b)
}

func (f fakeConn) Write(b []byte) (n int, err error) {
return f.pipeWriter.Write(b)
}

func (f fakeConn) Close() error {
f.pipeWriter.Close()
f.pipeReader.Close()
return nil
}

func (f fakeConn) LocalAddr() net.Addr {
return &net.TCPAddr{
IP: net.ParseIP(f.localAddr),
}
}

func (f fakeConn) RemoteAddr() net.Addr {
return &net.TCPAddr{
IP: net.ParseIP(f.remoteAddr),
}
}
func (f fakeConn) SetDeadline(t time.Time) error {
return fmt.Errorf("Unimplemented")
}
func (f fakeConn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("Unimplemented")
}
func (f fakeConn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("Unimplemented")
}

func TestProxyWithHeader(t *testing.T) {
clientIP := "127.0.0.1"
proxyIP := "127.0.0.2"
serverIP := "127.0.0.3"

server, proxy := NewFakeConn(proxyIP, serverIP)
wrappedConn, proxyAddr, err := proxyproto.WrapProxy(server, []string{proxyIP})
if err != nil {
t.Fatal(err)
}

header := &goproxyproto.Header{
Version: 1,
Command: goproxyproto.PROXY,
TransportProtocol: goproxyproto.TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP(clientIP),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP(proxyIP),
Port: 2000,
},
}
go func() {
_, err := header.WriteTo(proxy)
if err != nil {
return
}
}()

if proxyAddr == nil {
t.Fatalf("Proxy info was rejected")
}
if proxyAddr.String() != proxyIP+":0" {
t.Fatalf("Unexpected proxy address %s, expected %s", proxyAddr, proxyIP)
}
if wrappedConn.RemoteAddr().String() != clientIP+":1000" {
t.Fatalf("Header not accepted when it should be %s != %s", wrappedConn.RemoteAddr().String(), clientIP+":1000")
}
}

func TestProxyUnauthorizedHeader(t *testing.T) {
clientIP := "127.0.0.1"
proxyIP := "127.0.0.2"
serverIP := "127.0.0.3"

server, proxy := NewFakeConn(proxyIP, serverIP)
_, proxyAddr, err := proxyproto.WrapProxy(server, []string{"128.0.0.2"})
if err != nil {
t.Fatal(err)
}

header := &goproxyproto.Header{
Version: 1,
Command: goproxyproto.PROXY,
TransportProtocol: goproxyproto.TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP(clientIP),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP(proxyIP),
Port: 2000,
},
}
go func() {
_, err := header.WriteTo(proxy)
if err != nil {
return
}
}()

if proxyAddr != nil {
t.Fatalf("Proxy info was accepted when unauthorized")
}
}
24 changes: 8 additions & 16 deletions internal/sshserver/serverImpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (

"github.com/containerssh/libcontainerssh/auth"
"github.com/containerssh/libcontainerssh/config"
"github.com/containerssh/libcontainerssh/internal/proxyproto"
ssh2 "github.com/containerssh/libcontainerssh/internal/ssh"
"github.com/containerssh/libcontainerssh/log"
messageCodes "github.com/containerssh/libcontainerssh/message"
"github.com/containerssh/libcontainerssh/service"
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/ssh"
)

Expand Down Expand Up @@ -55,20 +55,12 @@ func (s *serverImpl) RunWithLifecycle(lifecycle service.Lifecycle) error {
Control: s.socketControl,
}

useProxy := len(s.cfg.AllowedProxies) > 0

netListener, err := listenConfig.Listen(lifecycle.Context(), "tcp", s.cfg.Listen)
if err != nil {
s.lock.Unlock()
return messageCodes.Wrap(err, messageCodes.ESSHStartFailed, "failed to start SSH server on %s", s.cfg.Listen)
}
if useProxy {
policy := proxyproto.MustStrictWhiteListPolicy(s.cfg.AllowedProxies)
netListener = &proxyproto.Listener{
Listener: netListener,
Policy: policy,
}
}

s.listenSocket = netListener
s.lock.Unlock()
if err := s.handler.OnReady(); err != nil {
Expand All @@ -87,19 +79,19 @@ func (s *serverImpl) RunWithLifecycle(lifecycle service.Lifecycle) error {
s.logger.Info(messageCodes.NewMessage(messageCodes.MSSHServiceAvailable, "SSH server running on %s", s.cfg.Listen))

go s.handleListenSocketOnShutdown(lifecycle)

for {
tcpConn, err := netListener.Accept()
if err != nil {
// Assume listen socket closed
break
}
s.wg.Add(1)
var proxy *net.TCPAddr
if useProxy {
proxyConn := tcpConn.(*proxyproto.Conn)
proxy = proxyConn.Raw().RemoteAddr().(*net.TCPAddr)
tcpConn, proxyAddr, err := proxyproto.WrapProxy(tcpConn, s.cfg.AllowedProxies)
if err != nil {
break
}
go s.handleConnection(tcpConn, proxy)
s.wg.Add(1)
go s.handleConnection(tcpConn, proxyAddr)
}
lifecycle.Stopping()
s.shuttingDown = true
Expand Down

0 comments on commit ddddfd4

Please sign in to comment.