Skip to content

Commit

Permalink
Merge pull request #123 from gliderlabs/optimize-add-host-key
Browse files Browse the repository at this point in the history
Update AddHostKey to avoid always appending
  • Loading branch information
belak authored Oct 23, 2019
2 parents 63518b5 + 1db07d8 commit 59d6e45
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 11 deletions.
38 changes: 37 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,17 @@ type Server struct {
RequestHandlers map[string]RequestHandler

listenerWg sync.WaitGroup
mu sync.Mutex
mu sync.RWMutex
listeners map[net.Listener]struct{}
conns map[*gossh.ServerConn]struct{}
connWg sync.WaitGroup
doneChan chan struct{}
}

func (srv *Server) ensureHostSigner() error {
srv.mu.Lock()
defer srv.mu.Unlock()

if len(srv.HostSigners) == 0 {
signer, err := generateSigner()
if err != nil {
Expand All @@ -79,6 +82,7 @@ func (srv *Server) ensureHostSigner() error {
func (srv *Server) ensureHandlers() {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.RequestHandlers == nil {
srv.RequestHandlers = map[string]RequestHandler{}
for k, v := range DefaultRequestHandlers {
Expand All @@ -94,6 +98,9 @@ func (srv *Server) ensureHandlers() {
}

func (srv *Server) config(ctx Context) *gossh.ServerConfig {
srv.mu.RLock()
defer srv.mu.RUnlock()

var config *gossh.ServerConfig
if srv.ServerConfigCallback == nil {
config = &gossh.ServerConfig{}
Expand Down Expand Up @@ -142,6 +149,9 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {

// Handle sets the Handler for the server.
func (srv *Server) Handle(fn Handler) {
srv.mu.Lock()
defer srv.mu.Unlock()

srv.Handler = fn
}

Expand All @@ -153,6 +163,7 @@ func (srv *Server) Handle(fn Handler) {
func (srv *Server) Close() error {
srv.mu.Lock()
defer srv.mu.Unlock()

srv.closeDoneChanLocked()
err := srv.closeListenersLocked()
for c := range srv.conns {
Expand Down Expand Up @@ -313,19 +324,42 @@ func (srv *Server) ListenAndServe() error {
// with the same algorithm, it is overwritten. Each server config must have at
// least one host key.
func (srv *Server) AddHostKey(key Signer) {
srv.mu.Lock()
defer srv.mu.Unlock()

// these are later added via AddHostKey on ServerConfig, which performs the
// check for one of every algorithm.

// This check is based on the AddHostKey method from the x/crypto/ssh
// library. This allows us to only keep one active key for each type on a
// server at once. So, if you're dynamically updating keys at runtime, this
// list will not keep growing.
for i, k := range srv.HostSigners {
if k.PublicKey().Type() == key.PublicKey().Type() {
srv.HostSigners[i] = key
return
}
}

srv.HostSigners = append(srv.HostSigners, key)
}

// SetOption runs a functional option against the server.
func (srv *Server) SetOption(option Option) error {
// NOTE: there is a potential race here for any option that doesn't call an
// internal method. We can't actually lock here because if something calls
// (as an example) AddHostKey, it will deadlock.

//srv.mu.Lock()
//defer srv.mu.Unlock()

return option(srv)
}

func (srv *Server) getDoneChan() <-chan struct{} {
srv.mu.Lock()
defer srv.mu.Unlock()

return srv.getDoneChanLocked()
}

Expand Down Expand Up @@ -362,6 +396,7 @@ func (srv *Server) closeListenersLocked() error {
func (srv *Server) trackListener(ln net.Listener, add bool) {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.listeners == nil {
srv.listeners = make(map[net.Listener]struct{})
}
Expand All @@ -382,6 +417,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) {
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
srv.mu.Lock()
defer srv.mu.Unlock()

if srv.conns == nil {
srv.conns = make(map[*gossh.ServerConn]struct{})
}
Expand Down
20 changes: 20 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@ import (
"time"
)

func TestAddHostKey(t *testing.T) {
s := Server{}
signer, err := generateSigner()
if err != nil {
t.Fatal(err)
}
s.AddHostKey(signer)
if len(s.HostSigners) != 1 {
t.Fatal("Key was not properly added")
}
signer, err = generateSigner()
if err != nil {
t.Fatal(err)
}
s.AddHostKey(signer)
if len(s.HostSigners) != 1 {
t.Fatal("Key was not properly replaced")
}
}

func TestServerShutdown(t *testing.T) {
l := newLocalListener()
testBytes := []byte("Hello world\n")
Expand Down
46 changes: 36 additions & 10 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,20 +289,40 @@ func TestPtyResize(t *testing.T) {
func TestSignals(t *testing.T) {
t.Parallel()

// errChan lets us get errors back from the session
errChan := make(chan error, 5)

// doneChan lets us specify that we should exit.
doneChan := make(chan interface{})

session, _, cleanup := newTestSession(t, &Server{
Handler: func(s Session) {
signals := make(chan Signal)
// We need to use a buffered channel here, otherwise it's possible for the
// second call to Signal to get discarded.
signals := make(chan Signal, 2)
s.Signals(signals)
if sig := <-signals; sig != SIGINT {
t.Fatalf("expected signal %v but got %v", SIGINT, sig)

select {
case sig := <-signals:
if sig != SIGINT {
errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig)
return
}
case <-doneChan:
errChan <- fmt.Errorf("Unexpected done")
return
}
exiter := make(chan bool)
go func() {
if sig := <-signals; sig == SIGKILL {
close(exiter)

select {
case sig := <-signals:
if sig != SIGKILL {
errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig)
return
}
}()
<-exiter
case <-doneChan:
errChan <- fmt.Errorf("Unexpected done")
return
}
},
}, nil)
defer cleanup()
Expand All @@ -312,7 +332,13 @@ func TestSignals(t *testing.T) {
session.Signal(gossh.SIGKILL)
}()

err := session.Run("")
go func() {
errChan <- session.Run("")
}()

err := <-errChan
close(doneChan)

if err != nil {
t.Fatalf("expected nil but got %v", err)
}
Expand Down

0 comments on commit 59d6e45

Please sign in to comment.