Skip to content

Commit

Permalink
Implement keepalive support (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsipinakis authored Jan 4, 2022
1 parent 7e0ee33 commit ab477ce
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 33 deletions.
15 changes: 15 additions & 0 deletions config/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"os"
"regexp"
"strings"
"time"

"golang.org/x/crypto/ssh"
)
Expand All @@ -34,6 +35,14 @@ type SSHConfig struct {
Banner string `json:"banner" yaml:"banner" comment:"Host banner to show after the username" default:""`
// HostKeys are the host keys either in PEM format, or filenames to load.
HostKeys []string `json:"hostkeys" yaml:"hostkeys" comment:"Host keys in PEM format or files to load PEM host keys from."`
// ClientAliveInterval is the duration between keep alive messages that
// ContainerSSH will send to each client. If the duration is 0 or unset
// it disables the feature.
ClientAliveInterval time.Duration `json:"clientAliveInterval" yaml:"clientAliveInterval" comment:"Inverval to send keepalive packets to the client"`
// ClientAliveCountMax is the number of keepalive messages that is
// allowed to be sent without a response being received. If this number
// is exceeded the connection is considered dead
ClientAliveCountMax int `json:"clientAliveCountMax" yaml:"clientAliveCountMax" default:"3" comment:"Maximum number of failed keepalives"`
}

// GenerateHostKey generates a random host key and adds it to SSHConfig
Expand Down Expand Up @@ -105,6 +114,12 @@ func (cfg SSHConfig) Validate() error {
if err := cfg.MACs.Validate(); err != nil {
return fmt.Errorf("invalid MAc list (%w)", err)
}
if cfg.ClientAliveInterval != 0 && cfg.ClientAliveInterval < 1 * time.Second {
return fmt.Errorf("clientAliveInterval should be at least 1 second long")
}
if cfg.ClientAliveCountMax <= 0 {
return fmt.Errorf("clientAliveCountMax should be at least 1")
}
return nil
}

Expand Down
4 changes: 2 additions & 2 deletions internal/auditlogintegration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestKeyboardInteractiveAuthentication(t *testing.T) {
user := sshserver.NewTestUser("test")
user.AddKeyboardInteractiveChallengeResponse("Challenge", "Response")

srv := sshserver.NewTestServer(t, auditLogHandler, logger)
srv := sshserver.NewTestServer(t, auditLogHandler, logger, nil)
srv.Start()
client := sshserver.NewTestClient(srv.GetListen(), srv.GetHostKey(), user, logger)
connection := client.MustConnect()
Expand Down Expand Up @@ -129,7 +129,7 @@ func createTestServer(t *testing.T, dir string, logger log.Logger) (sshserver.Te
)
assert.NoError(t, err)

srv := sshserver.NewTestServer(t, auditLogHandler, logger)
srv := sshserver.NewTestServer(t, auditLogHandler, logger, nil)
user := sshserver.NewTestUser("test")
user.RandomPassword()

Expand Down
14 changes: 8 additions & 6 deletions internal/sshserver/NewTestServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@ import (
)

// NewTestServer is a simplified API to start and stop a test server.
func NewTestServer(t *testing.T, handler Handler, logger log.Logger) TestServer {
config := config2.SSHConfig{}
structutils.Defaults(&config)
func NewTestServer(t *testing.T, handler Handler, logger log.Logger, config *config2.SSHConfig) TestServer {
if config == nil {
config = &config2.SSHConfig{}
structutils.Defaults(config)
}

port := test.GetNextPort(t, "SSH")
config.Listen = fmt.Sprintf("127.0.0.1:%d", port)
if err := config.GenerateHostKey(); err != nil {
panic(err)
}
svc, err := New(config, handler, logger)
svc, err := New(*config, handler, logger)
if err != nil {
panic(err)
}
Expand All @@ -31,15 +33,15 @@ func NewTestServer(t *testing.T, handler Handler, logger log.Logger) TestServer
lifecycle.OnRunning(
func(s service.Service, l service.Lifecycle) {
started <- struct{}{}
})
})

t.Cleanup(func() {
lifecycle.Stop(context.Background())
_ = lifecycle.Wait()
})

return &testServerImpl{
config: config,
config: *config,
lifecycle: lifecycle,
started: started,
}
Expand Down
82 changes: 82 additions & 0 deletions internal/sshserver/Server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func TestAuthKeyboardInteractive(t *testing.T) {
user2,
),
logger,
nil,
)
srv.Start()

Expand Down Expand Up @@ -245,6 +246,87 @@ func TestPubKey(t *testing.T) {
assert.Equal(t, []byte("Hello world!"), reply)
}

func TestKeepAlive(t *testing.T) {
//t.Parallel()()

logger := log.NewTestLogger(t)

user := sshserver.NewTestUser("test")
user.RandomPassword()

config := config.SSHConfig{}
structutils.Defaults(&config)
config.ClientAliveInterval = 1 * time.Second
srv := sshserver.NewTestServer(
t,
sshserver.NewTestAuthenticationHandler(
sshserver.NewTestHandler(),
user,
),
logger,
&config,
)
srv.Start()
defer srv.Stop(1 * time.Minute)


hostkey, err := ssh.ParsePrivateKey([]byte(srv.GetHostKey()))
if err != nil {
t.Fatal("Failed to parse private key")
}
sshConfig := &ssh.ClientConfig{
User: user.Username(),
Auth: user.GetAuthMethods(),
}
sshConfig.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if bytes.Equal(key.Marshal(), hostkey.PublicKey().Marshal()) {
return nil
}
return fmt.Errorf("invalid host")
}
tcpConnection, err := net.Dial("tcp", srv.GetListen())
if err != nil {
t.Fatal("tcp handshake failed (%w)", err)
}
connection, _, globalReq, err := ssh.NewClientConn(tcpConnection, srv.GetListen(), sshConfig)
if err != nil {
t.Fatal("ssh handshake failed (%w)", err)
}
defer func() {
_ = connection.Close()
}()

req := <-globalReq
err = req.Reply(false, nil)
if err != nil {
t.Fatal("Failed to respond to first request")
}
recv1 := time.Now()

req2 := <-globalReq
recv2 := time.Now()
err = req.Reply(false, nil)
if err != nil {
t.Fatal("Failed to respond to second request")
}

if req.Type != "[email protected]" {
t.Fatal("Expected keepalive request", req.Type)
}
if req2.Type != "[email protected]" {
t.Fatal("Expected keepalive request", req.Type)
}

elapsed := recv2.Sub(recv1)

if elapsed > 2 * time.Second {
t.Fatal("Received keepalive in too big of an interval", elapsed)
}
if elapsed < time.Second / 2 {
t.Fatal("Received keepalive in too short of an interval", elapsed)
}
}

//endregion

//region Helper
Expand Down
2 changes: 1 addition & 1 deletion internal/sshserver/TestUser.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (u *TestUser) GetAuthorizedKeys() []string {
return u.authorizedKeys
}

func (u *TestUser) getAuthMethods() []ssh.AuthMethod {
func (u *TestUser) GetAuthMethods() []ssh.AuthMethod {
var result []ssh.AuthMethod
if u.password != "" {
result = append(result, ssh.Password(u.password))
Expand Down
11 changes: 6 additions & 5 deletions internal/sshserver/conformanceTestSuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (c *conformanceTestSuite) singleProgramShouldRun(t *testing.T) {
srv := NewTestServer(t, NewTestAuthenticationHandler(
newConformanceTestHandler(backend),
user,
), logger)
), logger, nil)
srv.Start()
defer srv.Stop(1 * time.Minute)

Expand Down Expand Up @@ -70,7 +70,7 @@ func (c *conformanceTestSuite) settingEnvVariablesShouldWork(t *testing.T) {
srv := NewTestServer(t, NewTestAuthenticationHandler(
newConformanceTestHandler(backend),
user,
), logger)
), logger, nil)
srv.Start()
defer srv.Stop(1 * time.Minute)

Expand Down Expand Up @@ -117,7 +117,7 @@ func (c *conformanceTestSuite) runningInteractiveShellShouldWork(t *testing.T) {
srv := NewTestServer(t, NewTestAuthenticationHandler(
newConformanceTestHandler(backend),
user,
), logger)
), logger, nil)
srv.Start()
defer srv.Stop(1 * time.Minute)

Expand Down Expand Up @@ -201,7 +201,7 @@ func (c *conformanceTestSuite) reportingExitCodeShouldWork(t *testing.T) {
srv := NewTestServer(t, NewTestAuthenticationHandler(
newConformanceTestHandler(backend),
user,
), logger)
), logger, nil)
srv.Start()
defer srv.Stop(1 * time.Minute)

Expand Down Expand Up @@ -240,7 +240,8 @@ func (c *conformanceTestSuite) sendingSignalsShouldWork(t *testing.T) {
srv := NewTestServer(t, NewTestAuthenticationHandler(
newConformanceTestHandler(backend),
user,
), logger)
nil,
), logger, nil)
srv.Start()
defer srv.Stop(1 * time.Minute)

Expand Down
92 changes: 74 additions & 18 deletions internal/sshserver/serverImpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"strings"
"sync"
"time"

"github.com/containerssh/libcontainerssh/config"
ssh2 "github.com/containerssh/libcontainerssh/internal/ssh"
Expand Down Expand Up @@ -482,6 +483,35 @@ func (s *serverImpl) handleConnection(conn net.Conn) {
sshShutdownHandlerID := fmt.Sprintf("ssh-%s", connectionID)
s.lock.Unlock()

if s.cfg.ClientAliveInterval > 0 {
go func() {
missedAlives := 0
for {
time.Sleep(s.cfg.ClientAliveInterval)

_, _, err := sshConn.SendRequest("[email protected]", true, []byte{})

if err != nil {
missedAlives++

logger.Debug(
messageCodes.Wrap(
err,
messageCodes.ESSHKeepAliveFailed,
"Keepalive error",
),
)
if missedAlives >= s.cfg.ClientAliveCountMax {
_ = sshConn.Close()
break
}
continue
}
missedAlives = 0
}
}()
}

go func() {
_ = sshConn.Wait()
logger.Debug(messageCodes.NewMessage(messageCodes.MSSHDisconnected, "Client disconnected"))
Expand All @@ -498,6 +528,44 @@ func (s *serverImpl) handleConnection(conn net.Conn) {
go s.handleGlobalRequests(globalRequests, handlerSSHConnection, logger)
}

func (s *serverImpl) handleKeepAliveRequest(req *ssh.Request, logger log.Logger) {
if req.WantReply {
if err := req.Reply(false, []byte{}); err != nil {
logger.Debug(
messageCodes.Wrap(
err,
messageCodes.ESSHReplyFailed,
"failed to send reply to global request type %s",
req.Type,
),
)
}
}
}

func (s *serverImpl) handleUnknownGlobalRequest(req *ssh.Request, requestID uint64, connection SSHConnectionHandler, logger log.Logger) {
logger.Debug(
messageCodes.NewMessage(messageCodes.ESSHUnsupportedGlobalRequest, "Unsupported global request").Label(
"type",
req.Type,
),
)

connection.OnUnsupportedGlobalRequest(requestID, req.Type, req.Payload)
if req.WantReply {
if err := req.Reply(false, []byte("request type not supported")); err != nil {
logger.Debug(
messageCodes.Wrap(
err,
messageCodes.ESSHReplyFailed,
"failed to send reply to global request type %s",
req.Type,
),
)
}
}
}

func (s *serverImpl) handleGlobalRequests(
requests <-chan *ssh.Request,
connection SSHConnectionHandler,
Expand All @@ -510,24 +578,12 @@ func (s *serverImpl) handleGlobalRequests(
}
requestID := s.nextGlobalRequestID
s.nextGlobalRequestID++
logger.Debug(
messageCodes.NewMessage(messageCodes.ESSHUnsupportedGlobalRequest, "Unsupported global request").Label(
"type",
request.Type,
),
)
connection.OnUnsupportedGlobalRequest(requestID, request.Type, request.Payload)
if request.WantReply {
if err := request.Reply(false, []byte("request type not supported")); err != nil {
logger.Debug(
messageCodes.Wrap(
err,
messageCodes.ESSHReplyFailed,
"failed to send reply to global request type %s",
request.Type,
),
)
}

switch request.Type {
case "[email protected]":
s.handleKeepAliveRequest(request, logger)
default:
s.handleUnknownGlobalRequest(request, requestID, connection, logger)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions internal/sshserver/shutdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func TestProperShutdown(t *testing.T) {
user,
),
logger,
nil,
)
testServer.Start()

Expand Down
2 changes: 1 addition & 1 deletion internal/sshserver/testClientImpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (t *testClientImpl) Connect() (TestClientConnection, error) {
t.logger.Debug(messageCodes.NewMessage(messageCodes.MTest, "Connecting SSH server..."))
sshConfig := &ssh.ClientConfig{
User: t.user.Username(),
Auth: t.user.getAuthMethods(),
Auth: t.user.GetAuthMethods(),
}
sshConfig.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if bytes.Equal(key.Marshal(), t.hostKey) {
Expand Down
3 changes: 3 additions & 0 deletions message/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ const MSSHHandshakeSuccessful = "SSH_HANDSHAKE_SUCCESSFUL"
// This is nothing to worry about.
const ESSHUnsupportedGlobalRequest = "SSH_UNSUPPORTED_GLOBAL_REQUEST"

// ESSHKeepAliveFailed indicates that ContainerSSH couldn't send or didn't receive a response to a keepalive packet
const ESSHKeepAliveFailed = "SSH_KEEPALIVE_NORESP"

// ESSHReplyFailed indicates that ContainerSSH couldn't send the reply to a request to the user. This is usually the
// case if the user suddenly disconnects.
const ESSHReplyFailed = "SSH_REPLY_SEND_FAILED"
Expand Down

0 comments on commit ab477ce

Please sign in to comment.