-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7e0ee33
commit ab477ce
Showing
10 changed files
with
193 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,6 +111,7 @@ func TestAuthKeyboardInteractive(t *testing.T) { | |
user2, | ||
), | ||
logger, | ||
nil, | ||
) | ||
srv.Start() | ||
|
||
|
@@ -245,6 +246,87 @@ func TestPubKey(t *testing.T) { | |
assert.Equal(t, []byte("Hello world!"), reply) | ||
} | ||
|
||
func TestKeepAlive(t *testing.T) { | ||
//t.Parallel()() | ||
|
||
logger := log.NewTestLogger(t) | ||
|
||
user := sshserver.NewTestUser("test") | ||
user.RandomPassword() | ||
|
||
config := config.SSHConfig{} | ||
structutils.Defaults(&config) | ||
config.ClientAliveInterval = 1 * time.Second | ||
srv := sshserver.NewTestServer( | ||
t, | ||
sshserver.NewTestAuthenticationHandler( | ||
sshserver.NewTestHandler(), | ||
user, | ||
), | ||
logger, | ||
&config, | ||
) | ||
srv.Start() | ||
defer srv.Stop(1 * time.Minute) | ||
|
||
|
||
hostkey, err := ssh.ParsePrivateKey([]byte(srv.GetHostKey())) | ||
if err != nil { | ||
t.Fatal("Failed to parse private key") | ||
} | ||
sshConfig := &ssh.ClientConfig{ | ||
User: user.Username(), | ||
Auth: user.GetAuthMethods(), | ||
} | ||
sshConfig.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { | ||
if bytes.Equal(key.Marshal(), hostkey.PublicKey().Marshal()) { | ||
return nil | ||
} | ||
return fmt.Errorf("invalid host") | ||
} | ||
tcpConnection, err := net.Dial("tcp", srv.GetListen()) | ||
if err != nil { | ||
t.Fatal("tcp handshake failed (%w)", err) | ||
} | ||
connection, _, globalReq, err := ssh.NewClientConn(tcpConnection, srv.GetListen(), sshConfig) | ||
if err != nil { | ||
t.Fatal("ssh handshake failed (%w)", err) | ||
} | ||
defer func() { | ||
_ = connection.Close() | ||
}() | ||
|
||
req := <-globalReq | ||
err = req.Reply(false, nil) | ||
if err != nil { | ||
t.Fatal("Failed to respond to first request") | ||
} | ||
recv1 := time.Now() | ||
|
||
req2 := <-globalReq | ||
recv2 := time.Now() | ||
err = req.Reply(false, nil) | ||
if err != nil { | ||
t.Fatal("Failed to respond to second request") | ||
} | ||
|
||
if req.Type != "[email protected]" { | ||
t.Fatal("Expected keepalive request", req.Type) | ||
} | ||
if req2.Type != "[email protected]" { | ||
t.Fatal("Expected keepalive request", req.Type) | ||
} | ||
|
||
elapsed := recv2.Sub(recv1) | ||
|
||
if elapsed > 2 * time.Second { | ||
t.Fatal("Received keepalive in too big of an interval", elapsed) | ||
} | ||
if elapsed < time.Second / 2 { | ||
t.Fatal("Received keepalive in too short of an interval", elapsed) | ||
} | ||
} | ||
|
||
//endregion | ||
|
||
//region Helper | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ import ( | |
"net" | ||
"strings" | ||
"sync" | ||
"time" | ||
|
||
"github.com/containerssh/libcontainerssh/config" | ||
ssh2 "github.com/containerssh/libcontainerssh/internal/ssh" | ||
|
@@ -482,6 +483,35 @@ func (s *serverImpl) handleConnection(conn net.Conn) { | |
sshShutdownHandlerID := fmt.Sprintf("ssh-%s", connectionID) | ||
s.lock.Unlock() | ||
|
||
if s.cfg.ClientAliveInterval > 0 { | ||
go func() { | ||
missedAlives := 0 | ||
for { | ||
time.Sleep(s.cfg.ClientAliveInterval) | ||
|
||
_, _, err := sshConn.SendRequest("[email protected]", true, []byte{}) | ||
|
||
if err != nil { | ||
missedAlives++ | ||
|
||
logger.Debug( | ||
messageCodes.Wrap( | ||
err, | ||
messageCodes.ESSHKeepAliveFailed, | ||
"Keepalive error", | ||
), | ||
) | ||
if missedAlives >= s.cfg.ClientAliveCountMax { | ||
_ = sshConn.Close() | ||
break | ||
} | ||
continue | ||
} | ||
missedAlives = 0 | ||
} | ||
}() | ||
} | ||
|
||
go func() { | ||
_ = sshConn.Wait() | ||
logger.Debug(messageCodes.NewMessage(messageCodes.MSSHDisconnected, "Client disconnected")) | ||
|
@@ -498,6 +528,44 @@ func (s *serverImpl) handleConnection(conn net.Conn) { | |
go s.handleGlobalRequests(globalRequests, handlerSSHConnection, logger) | ||
} | ||
|
||
func (s *serverImpl) handleKeepAliveRequest(req *ssh.Request, logger log.Logger) { | ||
if req.WantReply { | ||
if err := req.Reply(false, []byte{}); err != nil { | ||
logger.Debug( | ||
messageCodes.Wrap( | ||
err, | ||
messageCodes.ESSHReplyFailed, | ||
"failed to send reply to global request type %s", | ||
req.Type, | ||
), | ||
) | ||
} | ||
} | ||
} | ||
|
||
func (s *serverImpl) handleUnknownGlobalRequest(req *ssh.Request, requestID uint64, connection SSHConnectionHandler, logger log.Logger) { | ||
logger.Debug( | ||
messageCodes.NewMessage(messageCodes.ESSHUnsupportedGlobalRequest, "Unsupported global request").Label( | ||
"type", | ||
req.Type, | ||
), | ||
) | ||
|
||
connection.OnUnsupportedGlobalRequest(requestID, req.Type, req.Payload) | ||
if req.WantReply { | ||
if err := req.Reply(false, []byte("request type not supported")); err != nil { | ||
logger.Debug( | ||
messageCodes.Wrap( | ||
err, | ||
messageCodes.ESSHReplyFailed, | ||
"failed to send reply to global request type %s", | ||
req.Type, | ||
), | ||
) | ||
} | ||
} | ||
} | ||
|
||
func (s *serverImpl) handleGlobalRequests( | ||
requests <-chan *ssh.Request, | ||
connection SSHConnectionHandler, | ||
|
@@ -510,24 +578,12 @@ func (s *serverImpl) handleGlobalRequests( | |
} | ||
requestID := s.nextGlobalRequestID | ||
s.nextGlobalRequestID++ | ||
logger.Debug( | ||
messageCodes.NewMessage(messageCodes.ESSHUnsupportedGlobalRequest, "Unsupported global request").Label( | ||
"type", | ||
request.Type, | ||
), | ||
) | ||
connection.OnUnsupportedGlobalRequest(requestID, request.Type, request.Payload) | ||
if request.WantReply { | ||
if err := request.Reply(false, []byte("request type not supported")); err != nil { | ||
logger.Debug( | ||
messageCodes.Wrap( | ||
err, | ||
messageCodes.ESSHReplyFailed, | ||
"failed to send reply to global request type %s", | ||
request.Type, | ||
), | ||
) | ||
} | ||
|
||
switch request.Type { | ||
case "[email protected]": | ||
s.handleKeepAliveRequest(request, logger) | ||
default: | ||
s.handleUnknownGlobalRequest(request, requestID, connection, logger) | ||
} | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ func TestProperShutdown(t *testing.T) { | |
user, | ||
), | ||
logger, | ||
nil, | ||
) | ||
testServer.Start() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters