diff --git a/auditlog/message/connect.go b/auditlog/message/connect.go index 29a0e10d..eb038f4e 100644 --- a/auditlog/message/connect.go +++ b/auditlog/message/connect.go @@ -2,8 +2,9 @@ package message // PayloadConnect is the payload for TypeConnect messages. type PayloadConnect struct { - RemoteAddr string `json:"remoteAddr" yaml:"remoteAddr"` // RemoteAddr contains the IP address of the connecting user. - Country string `json:"country" yaml:"country"` // Country contains the country code looked up from the IP address. Contains "XX" if the lookup failed. + RemoteAddr string `json:"remoteAddr" yaml:"remoteAddr"` // RemoteAddr contains the IP address of the connecting user. + ProxyAddr string `json:"proxyAddr,omitempty" yaml:"proxyAddr"` // ProxyAddr contains the IP adress of the proxy used (if behind a load balancer) + Country string `json:"country" yaml:"country"` // Country contains the country code looked up from the IP address. Contains "XX" if the lookup failed. } // Equals compares two PayloadConnect datasets. diff --git a/auditlog/message/type.go b/auditlog/message/type.go index 286fb5be..09609700 100644 --- a/auditlog/message/type.go +++ b/auditlog/message/type.go @@ -72,6 +72,9 @@ var typeToID = map[Type]string{ TypeAuthKeyboardInteractiveFailed: "auth_keyboard_interactive_failed", TypeAuthKeyboardInteractiveBackendError: "auth_keyboard_interactive_backend_error", + TypeHandshakeFailed: "handshake_failed", + TypeHandshakeSuccessful: "handshake_successful", + TypeGlobalRequestUnknown: "global_request_unknown", TypeNewChannel: "new_channel", TypeNewChannelSuccessful: "new_channel_successful", @@ -114,6 +117,9 @@ var typeToName = map[Type]string{ TypeAuthKeyboardInteractiveFailed: "Keyboard-interactive authentication failed", TypeAuthKeyboardInteractiveBackendError: "Keyboard-interactive authentication backend error", + TypeHandshakeFailed: "Handshake failed", + TypeHandshakeSuccessful: "Handshake successful", + TypeGlobalRequestUnknown: "Unknown global request", TypeNewChannel: "New channel request", TypeNewChannelSuccessful: "New channel successful", diff --git a/config/ssh.go b/config/ssh.go index 549538be..5226f3e1 100644 --- a/config/ssh.go +++ b/config/ssh.go @@ -25,6 +25,9 @@ type SSHConfig struct { // See https://tools.ietf.org/html/rfc4253#page-4 section 4.2. Protocol Version Exchange // The trailing CR and LF characters should NOT be added to this string. ServerVersion SSHServerVersion `json:"serverVersion" yaml:"serverVersion" default:"SSH-2.0-ContainerSSH"` + // AllowedProxies is a list of IP addresses or CIDR ranges that are allowed to use the + // PROXY protocol to override the connection originator IP address. + AllowedProxies []string `json:"allowedProxies" yaml:"allowedProxies"` // Ciphers are the ciphers offered to the client. Ciphers SSHCipherList `json:"ciphers" yaml:"ciphers" default:"[\"chacha20-poly1305@openssh.com\",\"aes256-gcm@openssh.com\",\"aes128-gcm@openssh.com\",\"aes256-ctr\",\"aes192-ctr\",\"aes128-ctr\"]" comment:"SSHCipher suites to use"` // KexAlgorithms are the key exchange algorithms offered to the client. diff --git a/go.mod b/go.mod index 4eeb666f..6bc225d0 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/mitchellh/mapstructure v1.4.3 github.com/opencontainers/image-spec v1.0.2 github.com/oschwald/geoip2-golang v1.5.0 + github.com/pires/go-proxyproto v0.6.1 github.com/qdm12/reprint v0.0.0-20200326205758-722754a53494 github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 diff --git a/go.sum b/go.sum index 7cac47ed..b096a581 100644 --- a/go.sum +++ b/go.sum @@ -628,6 +628,8 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/pelletier/go-toml v1.8.1 h1:1Nf83orprkJyknT6h7zbuEGUEjcyVlCxSUGTENmNCRM= github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= +github.com/pires/go-proxyproto v0.6.1 h1:EBupykFmo22SDjv4fQVQd2J9NOoLPmyZA/15ldOGkPw= +github.com/pires/go-proxyproto v0.6.1/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/internal/auditlog/codec/asciinema/encoder.go b/internal/auditlog/codec/asciinema/encoder.go index 5fa62029..9932028b 100644 --- a/internal/auditlog/codec/asciinema/encoder.go +++ b/internal/auditlog/codec/asciinema/encoder.go @@ -62,6 +62,7 @@ func (e *encoder) Encode(messages <-chan message.Message, storage storage.Writer startTime := int64(0) headerWritten := false var ip = "" + var proxy *string var username *string const shell = "/bin/sh" for { @@ -70,11 +71,12 @@ func (e *encoder) Encode(messages <-chan message.Message, storage storage.Writer break } var err error - startTime, headerWritten, ip, username, err = e.encodeMessage( + startTime, headerWritten, ip, proxy, username, err = e.encodeMessage( startTime, msg, &asciicastHeader, ip, + proxy, storage, username, headerWritten, @@ -106,11 +108,12 @@ func (e *encoder) encodeMessage( msg message.Message, asciicastHeader *Header, ip string, + proxy *string, storage storage.Writer, username *string, headerWritten bool, shell string, -) (int64, bool, string, *string, error) { +) (int64, bool, string, *string, *string, error) { if msg.MessageType == message.TypeConnect { startTime = msg.Timestamp asciicastHeader.Timestamp = int(startTime / 1000000000) @@ -121,11 +124,11 @@ func (e *encoder) encodeMessage( case message.TypeConnect: ip, username = e.handleConnect(storage, msg, startTime, country, username) case message.TypeAuthPasswordSuccessful: - ip, username = e.handleAuthPasswordSuccessful(storage, msg, startTime, ip, country) + ip, username = e.handleAuthPasswordSuccessful(storage, msg, startTime, ip, proxy, country) case message.TypeAuthPubKeySuccessful: - ip, username = e.handleAuthPubkeySuccessful(storage, msg, startTime, ip, country) + ip, username = e.handleAuthPubkeySuccessful(storage, msg, startTime, ip, proxy, country) case message.TypeHandshakeSuccessful: - ip, username = e.handleHandshakeSuccessful(storage, msg, startTime, ip, country) + ip, username = e.handleHandshakeSuccessful(storage, msg, startTime, ip, proxy, country) case message.TypeChannelRequestSetEnv: payload := msg.Payload.(message.PayloadChannelRequestSetEnv) asciicastHeader.Env[payload.Name] = payload.Value @@ -142,36 +145,40 @@ func (e *encoder) encodeMessage( startTime, headerWritten, err = e.handleIO(startTime, msg, asciicastHeader, headerWritten, shell, storage) } if err != nil { - return startTime, headerWritten, ip, username, err + return startTime, headerWritten, ip, proxy, username, err } - return startTime, headerWritten, ip, username, nil + return startTime, headerWritten, ip, proxy, username, nil } func (e *encoder) handleConnect(storage storage.Writer, msg message.Message, startTime int64, country string, username *string) (string, *string) { payload := msg.Payload.(message.PayloadConnect) ip := payload.RemoteAddr - storage.SetMetadata(startTime/1000000000, ip, country, username) + var proxy *string + if payload.ProxyAddr != "" { + proxy = &payload.ProxyAddr + } + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) return ip, username } -func (e *encoder) handleAuthPasswordSuccessful(storage storage.Writer, msg message.Message, startTime int64, ip string, country string) (string, *string) { +func (e *encoder) handleAuthPasswordSuccessful(storage storage.Writer, msg message.Message, startTime int64, ip string, proxy *string, country string) (string, *string) { payload := msg.Payload.(message.PayloadAuthPassword) username := &payload.Username - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) return ip, username } -func (e *encoder) handleAuthPubkeySuccessful(storage storage.Writer, msg message.Message, startTime int64, ip string, country string) (string, *string) { +func (e *encoder) handleAuthPubkeySuccessful(storage storage.Writer, msg message.Message, startTime int64, ip string, proxy *string, country string) (string, *string) { payload := msg.Payload.(message.PayloadAuthPubKey) username := &payload.Username - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) return ip, username } -func (e *encoder) handleHandshakeSuccessful(storage storage.Writer, msg message.Message, startTime int64, ip string, country string) (string, *string) { +func (e *encoder) handleHandshakeSuccessful(storage storage.Writer, msg message.Message, startTime int64, ip string, proxy *string, country string) (string, *string) { payload := msg.Payload.(message.PayloadHandshakeSuccessful) username := &payload.Username - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) return ip, username } diff --git a/internal/auditlog/codec/asciinema/encoder_test.go b/internal/auditlog/codec/asciinema/encoder_test.go index 5560b7ec..fa48c6db 100644 --- a/internal/auditlog/codec/asciinema/encoder_test.go +++ b/internal/auditlog/codec/asciinema/encoder_test.go @@ -19,6 +19,7 @@ type writer struct { data bytes.Buffer startTime int64 sourceIP string + proxyIP *string username *string wait chan bool country string @@ -43,9 +44,10 @@ func (w *writer) waitForClose() { <-w.wait } -func (w *writer) SetMetadata(startTime int64, sourceIP string, country string, username *string) { +func (w *writer) SetMetadata(startTime int64, sourceIP string, proxyIP *string, country string, username *string) { w.startTime = startTime w.sourceIP = sourceIP + w.proxyIP = proxyIP w.username = username w.country = country } diff --git a/internal/auditlog/codec/binary/encode.go b/internal/auditlog/codec/binary/encode.go index 145daff1..308d3278 100644 --- a/internal/auditlog/codec/binary/encode.go +++ b/internal/auditlog/codec/binary/encode.go @@ -48,6 +48,7 @@ func (e *encoder) Encode(messages <-chan message.Message, storage storage.Writer startTime := int64(0) var ip = "" + var proxy *string var country = "XX" var username *string for { @@ -58,7 +59,7 @@ func (e *encoder) Encode(messages <-chan message.Message, storage storage.Writer if startTime == 0 { startTime = msg.Timestamp } - ip, country, username = e.storeMetadata(msg, storage, startTime, ip, country, username) + ip, proxy, country, username = e.storeMetadata(msg, storage, startTime, ip, proxy, country, username) if err := encoder.Encode(&msg); err != nil { return fmt.Errorf("failed to encode audit log message (%w)", err) } @@ -83,28 +84,32 @@ func (e *encoder) storeMetadata( storage storage.Writer, startTime int64, ip string, + proxy *string, country string, username *string, -) (string, string, *string) { +) (string, *string, string, *string) { switch msg.MessageType { case message.TypeConnect: - remoteAddr := msg.Payload.(message.PayloadConnect).RemoteAddr - ip = remoteAddr + payload := msg.Payload.(message.PayloadConnect) + ip = payload.RemoteAddr + if payload.ProxyAddr != "" { + proxy = &payload.ProxyAddr + } country := e.geoIPProvider.Lookup(net.ParseIP(ip)) - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) case message.TypeAuthPasswordSuccessful: u := msg.Payload.(message.PayloadAuthPassword).Username username = &u - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) case message.TypeAuthPubKeySuccessful: payload := msg.Payload.(message.PayloadAuthPubKey) username = &payload.Username - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) case message.TypeHandshakeSuccessful: payload := msg.Payload.(message.PayloadHandshakeSuccessful) username = &payload.Username - storage.SetMetadata(startTime/1000000000, ip, country, username) + storage.SetMetadata(startTime/1000000000, ip, proxy, country, username) } - return ip, country, username + return ip, proxy, country, username } diff --git a/internal/auditlog/codec/proxy.go b/internal/auditlog/codec/proxy.go index 1ec8235f..ed980aad 100644 --- a/internal/auditlog/codec/proxy.go +++ b/internal/auditlog/codec/proxy.go @@ -23,6 +23,6 @@ func (s *storageWriterProxy) Close() error { return s.backend.Close() } -func (s *storageWriterProxy) SetMetadata(_ int64, _ string, _ string, _ *string) { +func (s *storageWriterProxy) SetMetadata(_ int64, _ string, _ *string, _ string, _ *string) { // No metadata storage } diff --git a/internal/auditlog/logger.go b/internal/auditlog/logger.go index 78ea8256..5454d8a1 100644 --- a/internal/auditlog/logger.go +++ b/internal/auditlog/logger.go @@ -12,7 +12,7 @@ import ( type Logger interface { // OnConnect creates an audit log message for a new connection and simultaneously returns a connection object for // connection-specific messages - OnConnect(connectionID message.ConnectionID, ip net.TCPAddr) (Connection, error) + OnConnect(connectionID message.ConnectionID, ip net.TCPAddr, proxy *net.TCPAddr) (Connection, error) // Shutdown triggers all failing uploads to cancel, waits for all currently running uploads to finish, then returns. // When the shutdownContext expires it will do its best to immediately upload any running background processes. Shutdown(shutdownContext context.Context) diff --git a/internal/auditlog/logger_empty.go b/internal/auditlog/logger_empty.go index 275b6eb5..9d670fe2 100644 --- a/internal/auditlog/logger_empty.go +++ b/internal/auditlog/logger_empty.go @@ -95,7 +95,7 @@ func (e *empty) OnNewChannelSuccess(_ message.ChannelID, _ string) Channel { return e } -func (e *empty) OnConnect(_ message.ConnectionID, _ net.TCPAddr) (Connection, error) { +func (e *empty) OnConnect(_ message.ConnectionID, _ net.TCPAddr, _ *net.TCPAddr) (Connection, error) { return e, nil } diff --git a/internal/auditlog/logger_impl.go b/internal/auditlog/logger_impl.go index 2b3ce7e2..c61d8609 100644 --- a/internal/auditlog/logger_impl.go +++ b/internal/auditlog/logger_impl.go @@ -55,7 +55,7 @@ func (l *loggerImplementation) Shutdown(shutdownContext context.Context) { //region Connection -func (l *loggerImplementation) OnConnect(connectionID message.ConnectionID, ip net.TCPAddr) (Connection, error) { +func (l *loggerImplementation) OnConnect(connectionID message.ConnectionID, ip net.TCPAddr, proxy *net.TCPAddr) (Connection, error) { name := string(connectionID) writer, err := l.storage.OpenWriter(name) if err != nil { @@ -76,12 +76,17 @@ func (l *loggerImplementation) OnConnect(connectionID message.ConnectionID, ip n l.logger.Emergency(err) } }() + proxyAddr := "" + if proxy != nil { + proxyAddr = proxy.IP.String() + } conn.log(message.Message{ ConnectionID: connectionID, Timestamp: time.Now().UnixNano(), MessageType: message.TypeConnect, Payload: message.PayloadConnect{ RemoteAddr: ip.IP.String(), + ProxyAddr: proxyAddr, Country: l.geoIPLookup.Lookup(ip.IP), }, ChannelID: nil, diff --git a/internal/auditlog/logger_test.go b/internal/auditlog/logger_test.go index c0ceece4..e75e88cf 100644 --- a/internal/auditlog/logger_test.go +++ b/internal/auditlog/logger_test.go @@ -203,6 +203,7 @@ func TestConnect(t *testing.T) { Port: 2222, Zone: "", }, + nil, ) if err != nil { assert.Fail(t, "failed to send connect message to logger", err) @@ -246,6 +247,7 @@ func TestAuth(t *testing.T) { Port: 2222, Zone: "", }, + nil, ) assert.Nil(t, err) connection.OnAuthPassword("foo", []byte("bar")) diff --git a/internal/auditlog/storage/file/struct.go b/internal/auditlog/storage/file/struct.go index f703b94d..b8fdb49f 100644 --- a/internal/auditlog/storage/file/struct.go +++ b/internal/auditlog/storage/file/struct.go @@ -74,5 +74,5 @@ func (w *writer) Close() error { return w.file.Close() } -func (w *writer) SetMetadata(_ int64, _ string, _ string, _ *string) { +func (w *writer) SetMetadata(_ int64, _ string, _ *string, _ string, _ *string) { } diff --git a/internal/auditlog/storage/none/writer.go b/internal/auditlog/storage/none/writer.go index 4bdebcba..e2cde6ac 100644 --- a/internal/auditlog/storage/none/writer.go +++ b/internal/auditlog/storage/none/writer.go @@ -3,7 +3,7 @@ package none type nullWriteCloser struct { } -func (w *nullWriteCloser) SetMetadata(_ int64, _ string, _ string, _ *string) { +func (w *nullWriteCloser) SetMetadata(_ int64, _ string, _ *string, _ string, _ *string) { } func (w *nullWriteCloser) Write(p []byte) (n int, err error) { diff --git a/internal/auditlog/storage/s3/queue.go b/internal/auditlog/storage/s3/queue.go index 73e06329..4ac82c51 100644 --- a/internal/auditlog/storage/s3/queue.go +++ b/internal/auditlog/storage/s3/queue.go @@ -22,6 +22,7 @@ var maxPartSize = uint(5 * 1024 * 1024 * 1024) type queueEntryMetadata struct { StartTime int64 `json:"startTime" yaml:"startTime"` RemoteAddr string `json:"remoteAddr" yaml:"remoteAddr"` + ProxyAddr string `json:"proxyAddr,omitempty" yaml:"proxyAddr"` Authenticated bool `json:"authenticated" yaml:"authenticated"` Username string `json:"username" yaml:"username"` Country string `json:"country" yaml:"country"` @@ -215,9 +216,14 @@ func (q *uploadQueue) getMonitoringWriter( return newMonitoringWriter( writeHandle, q.partSize, - func(startTime int64, remoteAddr string, country string, username *string) { + func(startTime int64, remoteAddr string, proxyIp *string, country string, username *string) { entry.metadata.StartTime = startTime entry.metadata.RemoteAddr = remoteAddr + if proxyIp == nil { + entry.metadata.ProxyAddr = "" + } else { + entry.metadata.ProxyAddr = *proxyIp + } entry.metadata.Country = country if username == nil { entry.metadata.Authenticated = false diff --git a/internal/auditlog/storage/s3/writer.go b/internal/auditlog/storage/s3/writer.go index f769f272..5950ea2a 100644 --- a/internal/auditlog/storage/s3/writer.go +++ b/internal/auditlog/storage/s3/writer.go @@ -9,7 +9,7 @@ import ( func newMonitoringWriter( backingWriter io.WriteCloser, partSize uint, - onMetadata func(startTime int64, remoteAddr string, country string, username *string), + onMetadata func(startTime int64, remoteAddr string, proxyIp *string, country string, username *string), onPart func(), onClose func(), ) storage.Writer { @@ -30,14 +30,14 @@ type monitoringWriter struct { backingWriter io.WriteCloser bytesWritten uint64 partSize uint - onMetadata func(startTime int64, remoteAddr string, country string, username *string) + onMetadata func(startTime int64, remoteAddr string, proxyIp *string, country string, username *string) onPart func() onClose func() lastPart int } -func (m *monitoringWriter) SetMetadata(startTime int64, sourceIP string, country string, username *string) { - m.onMetadata(startTime, sourceIP, country, username) +func (m *monitoringWriter) SetMetadata(startTime int64, sourceIP string, proxyIP *string, country string, username *string) { + m.onMetadata(startTime, sourceIP, proxyIP, country, username) } func (m *monitoringWriter) Write(p []byte) (n int, err error) { diff --git a/internal/auditlog/storage/storage.go b/internal/auditlog/storage/storage.go index 2cb09fea..6db35ee4 100644 --- a/internal/auditlog/storage/storage.go +++ b/internal/auditlog/storage/storage.go @@ -41,8 +41,9 @@ type Writer interface { // // startTime is the time when the connection started in unix timestamp. // sourceIp is the IP address the user connected from. + // proxyIp is the IP address the user connected with (or nil) // country is the ISO country code or "XX" if the lookup failed. // username is the username the user entered. The first time this method is called the username will be nil, // may be called subsequently is the user authenticated. - SetMetadata(startTime int64, sourceIP string, country string, username *string) + SetMetadata(startTime int64, sourceIP string, proxyIp *string, country string, username *string) } diff --git a/internal/auditlogintegration/handler.go b/internal/auditlogintegration/handler.go index af85c7a6..3a896d4c 100644 --- a/internal/auditlogintegration/handler.go +++ b/internal/auditlogintegration/handler.go @@ -34,12 +34,12 @@ func (h *handler) OnShutdown(shutdownContext context.Context) { wg.Wait() } -func (h *handler) OnNetworkConnection(client net.TCPAddr, connectionID string) (sshserver.NetworkConnectionHandler, error) { - backend, err := h.backend.OnNetworkConnection(client, connectionID) +func (h *handler) OnNetworkConnection(client net.TCPAddr, proxy *net.TCPAddr, connectionID string) (sshserver.NetworkConnectionHandler, error) { + backend, err := h.backend.OnNetworkConnection(client, proxy, connectionID) if err != nil { return nil, err } - auditConnection, err := h.logger.OnConnect(message.ConnectionID(connectionID), client) + auditConnection, err := h.logger.OnConnect(message.ConnectionID(connectionID), client, proxy) if err != nil { return nil, fmt.Errorf( "failed to initialize audit logger for connection from %s (%w)", diff --git a/internal/auditlogintegration/integration_test.go b/internal/auditlogintegration/integration_test.go index 06cc0155..a862cad6 100644 --- a/internal/auditlogintegration/integration_test.go +++ b/internal/auditlogintegration/integration_test.go @@ -13,9 +13,9 @@ import ( "github.com/containerssh/libcontainerssh/auditlog/message" auth2 "github.com/containerssh/libcontainerssh/auth" "github.com/containerssh/libcontainerssh/config" - "github.com/containerssh/libcontainerssh/internal/auth" "github.com/containerssh/libcontainerssh/internal/auditlog/codec/binary" "github.com/containerssh/libcontainerssh/internal/auditlog/storage/file" + "github.com/containerssh/libcontainerssh/internal/auth" "github.com/containerssh/libcontainerssh/internal/geoip" "github.com/containerssh/libcontainerssh/internal/sshserver" "github.com/containerssh/libcontainerssh/log" @@ -352,6 +352,7 @@ func (b *backendHandler) OnShutdown(_ context.Context) { func (b *backendHandler) OnNetworkConnection( _ net.TCPAddr, + _ *net.TCPAddr, _ string, ) (sshserver.NetworkConnectionHandler, error) { return b, nil diff --git a/internal/authintegration/handler.go b/internal/authintegration/handler.go index a620d3db..4d98d181 100644 --- a/internal/authintegration/handler.go +++ b/internal/authintegration/handler.go @@ -59,11 +59,11 @@ func (h *handler) OnShutdown(shutdownContext context.Context) { } } -func (h *handler) OnNetworkConnection(client net.TCPAddr, connectionID string) (sshserver.NetworkConnectionHandler, error) { +func (h *handler) OnNetworkConnection(client net.TCPAddr, proxy *net.TCPAddr, connectionID string) (sshserver.NetworkConnectionHandler, error) { var backend sshserver.NetworkConnectionHandler = nil var err error if h.backend != nil { - backend, err = h.backend.OnNetworkConnection(client, connectionID) + backend, err = h.backend.OnNetworkConnection(client, proxy, connectionID) if err != nil { return nil, err } diff --git a/internal/authintegration/integration_test.go b/internal/authintegration/integration_test.go index 8670d598..d0234d39 100644 --- a/internal/authintegration/integration_test.go +++ b/internal/authintegration/integration_test.go @@ -167,7 +167,7 @@ func (t *testBackend) OnShutdown(_ context.Context) { } -func (t *testBackend) OnNetworkConnection(_ net.TCPAddr, _ string) ( +func (t *testBackend) OnNetworkConnection(_ net.TCPAddr, _ *net.TCPAddr, _ string) ( sshserver.NetworkConnectionHandler, error, ) { diff --git a/internal/backend/handler.go b/internal/backend/handler.go index 780c7c5d..7553243c 100644 --- a/internal/backend/handler.go +++ b/internal/backend/handler.go @@ -36,12 +36,19 @@ type handler struct { func (h *handler) OnNetworkConnection( remoteAddr net.TCPAddr, + proxy *net.TCPAddr, connectionID string, ) (sshserver.NetworkConnectionHandler, error) { + logger := h.logger. + WithLabel("connectionId", connectionID). + WithLabel("remoteAddr", remoteAddr.IP.String()) + + if proxy != nil { + logger = logger.WithLabel("fromProxy", proxy.IP.String()) + } + return &networkHandler{ - logger: h.logger. - WithLabel("connectionId", connectionID). - WithLabel("remoteAddr", remoteAddr.IP.String()), + logger: logger, rootHandler: h, remoteAddr: remoteAddr, connectionID: connectionID, diff --git a/internal/metricsintegration/handler.go b/internal/metricsintegration/handler.go index 7e428007..964b0f72 100644 --- a/internal/metricsintegration/handler.go +++ b/internal/metricsintegration/handler.go @@ -30,9 +30,10 @@ func (m *metricsHandler) OnShutdown(shutdownContext context.Context) { func (m *metricsHandler) OnNetworkConnection( client net.TCPAddr, + proxy *net.TCPAddr, connectionID string, ) (sshserver.NetworkConnectionHandler, error) { - networkBackend, err := m.backend.OnNetworkConnection(client, connectionID) + networkBackend, err := m.backend.OnNetworkConnection(client, proxy, connectionID) if err != nil { return networkBackend, err } diff --git a/internal/metricsintegration/integration_test.go b/internal/metricsintegration/integration_test.go index 1ffc0551..a325ad9c 100644 --- a/internal/metricsintegration/integration_test.go +++ b/internal/metricsintegration/integration_test.go @@ -51,6 +51,7 @@ func testAuthSuccessful( IP: net.ParseIP("127.0.0.1"), Port: 2222, }, + nil, sshserver.GenerateConnectionID(), ) if !assert.NoError(t, err) { @@ -95,6 +96,7 @@ func testAuthFailed( IP: net.ParseIP("127.0.0.1"), Port: 2222, }, + nil, sshserver.GenerateConnectionID(), ) assert.NoError(t, err) @@ -129,6 +131,7 @@ func (d *dummyBackendHandler) OnShutdown(_ context.Context) { func (d *dummyBackendHandler) OnNetworkConnection( _ net.TCPAddr, + _ *net.TCPAddr, _ string, ) (sshserver.NetworkConnectionHandler, error) { return d, nil diff --git a/internal/proxyproto/proxyproto.go b/internal/proxyproto/proxyproto.go new file mode 100644 index 00000000..ea1a207d --- /dev/null +++ b/internal/proxyproto/proxyproto.go @@ -0,0 +1,34 @@ +package proxyproto + +import ( + "net" + + "github.com/pires/go-proxyproto" +) + +// WrapProxy is a function that wraps a net.Conn around the PROXY tcp protocol. It is used for correctly reporting the originator IP address when a service is running behind a load balancer +// In case proxy use is allowed the wrapped network connection is returned along with the IP address of the proxy that it is used. The wrapped network connection will return the IP address +// of the client when RemoteAddr() is called +// +// conn is the network connection to wrap +// proxyList is a list of addresses that are allowed to send proxy information +// +func WrapProxy(conn net.Conn, proxyList []string) (net.Conn, *net.TCPAddr, error) { + if len(proxyList) == 0 { + return conn, nil, nil + } + policyFunc := proxyproto.MustStrictWhiteListPolicy(proxyList) + policy, err := policyFunc(conn.RemoteAddr()) + if err != nil { + return nil, nil, err + } + if policy == proxyproto.REJECT || policy == proxyproto.IGNORE { + // If it's not an approved proxy we should fail loudly, not silently + return conn, nil, nil + } + tcpAddr := conn.RemoteAddr().(*net.TCPAddr) + return proxyproto.NewConn( + conn, + proxyproto.WithPolicy(policy), + ), tcpAddr, nil +} diff --git a/internal/proxyproto/proxyproto_test.go b/internal/proxyproto/proxyproto_test.go new file mode 100644 index 00000000..0d18c5da --- /dev/null +++ b/internal/proxyproto/proxyproto_test.go @@ -0,0 +1,148 @@ +package proxyproto_test + +import ( + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/containerssh/libcontainerssh/internal/proxyproto" + goproxyproto "github.com/pires/go-proxyproto" +) + +type fakeConn struct { + remoteAddr string + localAddr string + pipeReader io.ReadCloser + pipeWriter io.WriteCloser +} + +func NewFakeConn(clientAddr string, serverAddr string) (fakeConn, fakeConn) { + clientPipeReader, clientPipeWriter := io.Pipe() + serverPipeReader, serverPipeWriter := io.Pipe() + return fakeConn{ + remoteAddr: clientAddr, + localAddr: serverAddr, + pipeReader: serverPipeReader, + pipeWriter: clientPipeWriter, + }, fakeConn{ + remoteAddr: serverAddr, + localAddr: clientAddr, + pipeReader: clientPipeReader, + pipeWriter: serverPipeWriter, + } +} + +func (f fakeConn) Read(b []byte) (n int, err error) { + return f.pipeReader.Read(b) +} + +func (f fakeConn) Write(b []byte) (n int, err error) { + return f.pipeWriter.Write(b) +} + +func (f fakeConn) Close() error { + f.pipeWriter.Close() + f.pipeReader.Close() + return nil +} + +func (f fakeConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP(f.localAddr), + } +} + +func (f fakeConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP(f.remoteAddr), + } +} +func (f fakeConn) SetDeadline(t time.Time) error { + return fmt.Errorf("Unimplemented") +} +func (f fakeConn) SetReadDeadline(t time.Time) error { + return fmt.Errorf("Unimplemented") +} +func (f fakeConn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("Unimplemented") +} + +func TestProxyWithHeader(t *testing.T) { + clientIP := "127.0.0.1" + proxyIP := "127.0.0.2" + serverIP := "127.0.0.3" + + server, proxy := NewFakeConn(proxyIP, serverIP) + wrappedConn, proxyAddr, err := proxyproto.WrapProxy(server, []string{proxyIP}) + if err != nil { + t.Fatal(err) + } + + header := &goproxyproto.Header{ + Version: 1, + Command: goproxyproto.PROXY, + TransportProtocol: goproxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(clientIP), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(proxyIP), + Port: 2000, + }, + } + go func() { + _, err := header.WriteTo(proxy) + if err != nil { + return + } + }() + + if proxyAddr == nil { + t.Fatalf("Proxy info was rejected") + } + if proxyAddr.String() != proxyIP+":0" { + t.Fatalf("Unexpected proxy address %s, expected %s", proxyAddr, proxyIP) + } + if wrappedConn.RemoteAddr().String() != clientIP+":1000" { + t.Fatalf("Header not accepted when it should be %s != %s", wrappedConn.RemoteAddr().String(), clientIP+":1000") + } +} + +func TestProxyUnauthorizedHeader(t *testing.T) { + clientIP := "127.0.0.1" + proxyIP := "127.0.0.2" + serverIP := "127.0.0.3" + + server, proxy := NewFakeConn(proxyIP, serverIP) + _, proxyAddr, err := proxyproto.WrapProxy(server, []string{"128.0.0.2"}) + if err != nil { + t.Fatal(err) + } + + header := &goproxyproto.Header{ + Version: 1, + Command: goproxyproto.PROXY, + TransportProtocol: goproxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP(clientIP), + Port: 1000, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP(proxyIP), + Port: 2000, + }, + } + go func() { + _, err := header.WriteTo(proxy) + if err != nil { + return + } + }() + + if proxyAddr != nil { + t.Fatalf("Proxy info was accepted when unauthorized") + } +} diff --git a/internal/sshserver/AbstractHandler.go b/internal/sshserver/AbstractHandler.go index 6c168b87..fc0e5c67 100644 --- a/internal/sshserver/AbstractHandler.go +++ b/internal/sshserver/AbstractHandler.go @@ -28,6 +28,6 @@ func (a *AbstractHandler) OnShutdown(_ context.Context) { // // The ip parameter provides the IP address of the connecting user. The connectionID parameter provides an opaque // binary identifier for the connection that can be used to track the connection across multiple subsystems. -func (a *AbstractHandler) OnNetworkConnection(_ net.TCPAddr, _ string) (NetworkConnectionHandler, error) { +func (a *AbstractHandler) OnNetworkConnection(_ net.TCPAddr, _ *net.TCPAddr, _ string) (NetworkConnectionHandler, error) { return nil, fmt.Errorf("not implemented") } diff --git a/internal/sshserver/Server_test.go b/internal/sshserver/Server_test.go index bde94938..6f9feb92 100644 --- a/internal/sshserver/Server_test.go +++ b/internal/sshserver/Server_test.go @@ -344,7 +344,6 @@ func TestKeepAlive(t *testing.T) { srv.Start() defer srv.Stop(1 * time.Minute) - hostkey, err := ssh.ParsePrivateKey([]byte(srv.GetHostKey())) if err != nil { t.Fatal("Failed to parse private key") @@ -394,10 +393,10 @@ func TestKeepAlive(t *testing.T) { elapsed := recv2.Sub(recv1) - if elapsed > 2 * time.Second { + if elapsed > 2*time.Second { t.Fatal("Received keepalive in too big of an interval", elapsed) } - if elapsed < time.Second / 2 { + if elapsed < time.Second/2 { t.Fatal("Received keepalive in too short of an interval", elapsed) } } @@ -601,7 +600,7 @@ func (r *rejectHandler) OnReady() error { func (r *rejectHandler) OnShutdown(_ context.Context) { } -func (r *rejectHandler) OnNetworkConnection(_ net.TCPAddr, _ string) (sshserver.NetworkConnectionHandler, error) { +func (r *rejectHandler) OnNetworkConnection(_ net.TCPAddr, _ *net.TCPAddr, _ string) (sshserver.NetworkConnectionHandler, error) { return nil, fmt.Errorf("not implemented") } @@ -651,7 +650,7 @@ func (f *fullHandler) OnShutdown(shutdownContext context.Context) { close(f.shutdownDone) } -func (f *fullHandler) OnNetworkConnection(_ net.TCPAddr, _ string) (sshserver.NetworkConnectionHandler, error) { +func (f *fullHandler) OnNetworkConnection(_ net.TCPAddr, _ *net.TCPAddr, _ string) (sshserver.NetworkConnectionHandler, error) { return &fullNetworkConnectionHandler{ handler: f, }, nil diff --git a/internal/sshserver/conformanceTestHandler.go b/internal/sshserver/conformanceTestHandler.go index 4adc56b6..c2a70f46 100644 --- a/internal/sshserver/conformanceTestHandler.go +++ b/internal/sshserver/conformanceTestHandler.go @@ -10,6 +10,6 @@ type conformanceTestHandler struct { backend NetworkConnectionHandler } -func (h *conformanceTestHandler) OnNetworkConnection(_ net.TCPAddr, _ string) (NetworkConnectionHandler, error) { +func (h *conformanceTestHandler) OnNetworkConnection(_ net.TCPAddr, _ *net.TCPAddr, _ string) (NetworkConnectionHandler, error) { return h.backend, nil } diff --git a/internal/sshserver/handler.go b/internal/sshserver/handler.go index 498da02c..7fa73206 100644 --- a/internal/sshserver/handler.go +++ b/internal/sshserver/handler.go @@ -28,9 +28,9 @@ type Handler interface { // OnNetworkConnection is called when a new network connection is opened. It must either return a // NetworkConnectionHandler object or an error. In case of an error the network connection is closed. // - // The ip parameter provides the IP address of the connecting user. The connectionID parameter provides an opaque - // binary identifier for the connection that can be used to track the connection across multiple subsystems. - OnNetworkConnection(client net.TCPAddr, connectionID string) (NetworkConnectionHandler, error) + // The ip parameter provides the IP address of the connecting user. The proxy parameter provides the IP of the load balancer if one is used. + // The connectionID parameter provides an opaque binary identifier for the connection that can be used to track the connection across multiple subsystems. + OnNetworkConnection(client net.TCPAddr, proxy *net.TCPAddr, connectionID string) (NetworkConnectionHandler, error) } // AuthResponse is the result of the authentication process. diff --git a/internal/sshserver/serverImpl.go b/internal/sshserver/serverImpl.go index e85c423f..fe71f358 100644 --- a/internal/sshserver/serverImpl.go +++ b/internal/sshserver/serverImpl.go @@ -10,6 +10,7 @@ import ( "github.com/containerssh/libcontainerssh/auth" "github.com/containerssh/libcontainerssh/config" + "github.com/containerssh/libcontainerssh/internal/proxyproto" ssh2 "github.com/containerssh/libcontainerssh/internal/ssh" "github.com/containerssh/libcontainerssh/log" messageCodes "github.com/containerssh/libcontainerssh/message" @@ -59,6 +60,7 @@ func (s *serverImpl) RunWithLifecycle(lifecycle service.Lifecycle) error { s.lock.Unlock() return messageCodes.Wrap(err, messageCodes.ESSHStartFailed, "failed to start SSH server on %s", s.cfg.Listen) } + s.listenSocket = netListener s.lock.Unlock() if err := s.handler.OnReady(); err != nil { @@ -77,14 +79,19 @@ func (s *serverImpl) RunWithLifecycle(lifecycle service.Lifecycle) error { s.logger.Info(messageCodes.NewMessage(messageCodes.MSSHServiceAvailable, "SSH server running on %s", s.cfg.Listen)) go s.handleListenSocketOnShutdown(lifecycle) + for { tcpConn, err := netListener.Accept() if err != nil { // Assume listen socket closed break } + tcpConn, proxyAddr, err := proxyproto.WrapProxy(tcpConn, s.cfg.AllowedProxies) + if err != nil { + break + } s.wg.Add(1) - go s.handleConnection(tcpConn) + go s.handleConnection(tcpConn, proxyAddr) } lifecycle.Stopping() s.shuttingDown = true @@ -318,7 +325,7 @@ func (s *serverImpl) createConfiguration( PasswordCallback: passwordCallback, PublicKeyCallback: pubkeyCallback, KeyboardInteractiveCallback: keyboardInteractiveCallback, - GSSAPIWithMICConfig: gssConfig, + GSSAPIWithMICConfig: gssConfig, ServerVersion: s.cfg.ServerVersion.String(), BannerCallback: func(conn ssh.ConnMetadata) string { return s.cfg.Banner }, } @@ -347,7 +354,7 @@ func (s *serverImpl) createAuthenticators( func (s *serverImpl) createGSSAPIConfig( handlerNetworkConnection *networkConnectionWrapper, logger log.Logger, -) (*ssh.GSSAPIWithMICConfig){ +) *ssh.GSSAPIWithMICConfig { var gssConfig *ssh.GSSAPIWithMICConfig gssServer := handlerNetworkConnection.OnAuthGSSAPI() @@ -501,13 +508,18 @@ func (s *serverImpl) createPasswordCallback( return passwordCallback } -func (s *serverImpl) handleConnection(conn net.Conn) { +func (s *serverImpl) handleConnection(conn net.Conn, proxy *net.TCPAddr) { addr := conn.RemoteAddr().(*net.TCPAddr) connectionID := GenerateConnectionID() logger := s.logger. WithLabel("remoteAddr", addr.IP.String()). WithLabel("connectionId", connectionID) - handlerNetworkConnection, err := s.handler.OnNetworkConnection(*addr, connectionID) + + if proxy != nil { + logger = logger.WithLabel("fromProxy", proxy.IP.String()) + } + + handlerNetworkConnection, err := s.handler.OnNetworkConnection(*addr, proxy, connectionID) if err != nil { logger.Info(err) _ = conn.Close() diff --git a/internal/sshserver/testAuthenticationHandler.go b/internal/sshserver/testAuthenticationHandler.go index 0c2977da..2e983f5b 100644 --- a/internal/sshserver/testAuthenticationHandler.go +++ b/internal/sshserver/testAuthenticationHandler.go @@ -21,9 +21,10 @@ func (t *testAuthenticationHandler) OnShutdown(ctx context.Context) { func (t *testAuthenticationHandler) OnNetworkConnection( client net.TCPAddr, + proxy *net.TCPAddr, connectionID string, ) (NetworkConnectionHandler, error) { - backend, err := t.backend.OnNetworkConnection(client, connectionID) + backend, err := t.backend.OnNetworkConnection(client, proxy, connectionID) if err != nil { return nil, err } diff --git a/internal/sshserver/testHandlerImpl.go b/internal/sshserver/testHandlerImpl.go index e77254da..88692af9 100644 --- a/internal/sshserver/testHandlerImpl.go +++ b/internal/sshserver/testHandlerImpl.go @@ -16,7 +16,7 @@ func (t *testHandlerImpl) OnShutdown(_ context.Context) { t.shutdown = true } -func (t *testHandlerImpl) OnNetworkConnection(client net.TCPAddr, connectionID string) (NetworkConnectionHandler, error) { +func (t *testHandlerImpl) OnNetworkConnection(client net.TCPAddr, proxy *net.TCPAddr, connectionID string) (NetworkConnectionHandler, error) { return &testNetworkHandlerImpl{ client: client, connectionID: connectionID,