diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0e9179f --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof + +*.sw? + +# Project specific +.uptodate diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f6a7ca6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Paul Bellamy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..0985757 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# SSH + +Easier ssh server wrapper, net/http-style. + +Work-in-progress diff --git a/example/server.go b/example/server.go new file mode 100644 index 0000000..db1ced8 --- /dev/null +++ b/example/server.go @@ -0,0 +1,97 @@ +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path/filepath" + + "github.com/paulbellamy/ssh" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +func main() { + addr := flag.String("addr", ":12345", "address to listen on") + flag.Parse() + + // Configure the server + server := &ssh.Server{ + Addr: *addr, + Handler: ssh.HandlerFunc(handle), + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState: func(conn net.Conn, state ssh.ConnState) { + log.Printf("[ConnState] %v: %s", conn.RemoteAddr(), state) + }, + + ServerConfig: &gossh.ServerConfig{ + PasswordCallback: func(c gossh.ConnMetadata, pass []byte) (*gossh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + }, + } + + // Generate a random key for now + tempDir, err := ioutil.TempDir("", "") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(tempDir) + privKeyPath := filepath.Join(tempDir, "key") + out, err := exec.Command("ssh-keygen", "-f", privKeyPath, "-t", "rsa", "-N", "").CombinedOutput() + if err != nil { + log.Fatalf("Fail to generate private key: %v - %q", err, out) + } + privKeyFile, err := os.Open(privKeyPath) + if err != nil { + log.Fatal(err) + } + defer privKeyFile.Close() + server.AddHostKey(privKeyFile) + + // Start the server + log.Println("Listening on:", *addr) + log.Fatal(server.ListenAndServe()) +} + +func handle(p *ssh.Permissions, c ssh.Channel, r <-chan *ssh.Request) { + term := terminal.NewTerminal(c, "> ") + + go func() { + defer c.Close() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + + for req := range r { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any + // commands, only the + // default shell. + ok = false + } + } + req.Reply(ok, nil) + } +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..b1067ef --- /dev/null +++ b/handler.go @@ -0,0 +1,23 @@ +package ssh + +import "golang.org/x/crypto/ssh" + +var DefaultHandler = HandlerFunc(func(p *Permissions, c Channel, r <-chan *Request) { + ssh.DiscardRequests(unwrapRequests(r)) +}) + +// Handler handles ssh connections +type Handler interface { + ServeSSH(*Permissions, Channel, <-chan *Request) // TODO: Finish filling this in +} + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as handlers. If f is a function +// with the appropriate signature, HandlerFunc(f) is a +// Handler that calls f. +type HandlerFunc func(*Permissions, Channel, <-chan *Request) + +// ServeSSH calls f(p, c, r). +func (f HandlerFunc) ServeSSH(p *Permissions, c Channel, r <-chan *Request) { + f(p, c, r) +} diff --git a/listen.go b/listen.go new file mode 100644 index 0000000..421d051 --- /dev/null +++ b/listen.go @@ -0,0 +1,304 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "runtime" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +var ( + // ServerContextKey is a context key. It can be used in handlers with + // context.WithValue to access the server that started the handler. The + // associated value will be of type *Server. + ServerContextKey = &contextKey{"http-server"} + + // LocalAddrContextKey is a context key. It can be used in handlers with + // context.WithValue to access the address the local address the connection + // arrived on. The associated value will be of type net.Addr. + LocalAddrContextKey = &contextKey{"local-addr"} + + // ErrServerHasNoHostKeys is returned when you try to start a server without adding any host keys + ErrServerHasNoHostKeys = errors.New("server has no host keys") +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +// Server is an SSH server +type Server struct { + Addr string // TCP address to listen on, ":ssh" if empty + Handler Handler // handler to invoke, DefaultHandler if nil + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, ConnState) + + // ErrorLog specifies an optional logger for errors accepting + // connections and unexpected behavior from handlers. + // If nil, logging goes to os.Stderr via the log package's + // standard logger. + ErrorLog *log.Logger + + // HostKey count to track the number of added host keys. + // + // This sucks a bit, but there's no way to get the number of host keys from + // the ServerConfig directly. + HostKeyCount int + *ssh.ServerConfig +} + +func (srv *Server) AddHostKey(r io.Reader) error { + privateBytes, err := ioutil.ReadAll(r) + if err != nil { + return err + } + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + return err + } + srv.HostKeyCount++ + cfg := srv.ServerConfig + cfg.AddHostKey(private) + return nil +} + +// ListenAndServe listens on the TCP network address srv.Addr and then calls +// Serve to handle requests on incoming connections. Accepted connections are +// configured to enable TCP keep-alives. If srv.Addr is blank, ":ssh" is used. +// ListenAndServe always returns a non-nil error. +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":ssh" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) +} + +// Serve accepts incoming connections on the Listener l, creating a new service +// goroutine for each. The service goroutines read requests and then call +// srv.Handler to reply to them. +// +// Serve always returns a non-nil error. +func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + var tempDelay time.Duration + + if srv.HostKeyCount == 0 { + return ErrServerHasNoHostKeys + } + + baseCtx := context.Background() + ctx := context.WithValue(baseCtx, ServerContextKey, srv) + ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) + for { + rw, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + srv.logf("ssh: Accept error: %v; retrying in %v", e, tempDelay) + time.Sleep(tempDelay) + continue + } + return e + } + tempDelay = 0 + c := srv.newConn(rw) + c.setState(c.rwc, StateNew) // before Serve can return + go c.serve(ctx) + } +} + +// debugServerConnections controls whether all server connections are wrapped +// with a verbose logging wrapper. +const debugServerConnections = false + +// Create new connection from rwc. +func (srv *Server) newConn(rwc net.Conn) *conn { + c := &conn{ + server: srv, + rwc: rwc, + } + if debugServerConnections { + c.rwc = newLoggingConn("server", c.rwc) + } + return c +} + +func (srv *Server) logf(format string, args ...interface{}) { + if srv.ErrorLog != nil { + srv.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +// loggingConn is used for debugging. +type loggingConn struct { + name string + net.Conn +} + +var ( + uniqNameMu sync.Mutex + uniqNameNext = make(map[string]int) +) + +func newLoggingConn(baseName string, c net.Conn) net.Conn { + uniqNameMu.Lock() + defer uniqNameMu.Unlock() + uniqNameNext[baseName]++ + return &loggingConn{ + name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]), + Conn: c, + } +} + +func (c *loggingConn) Write(p []byte) (n int, err error) { + log.Printf("%s.Write(%d) = ....", c.name, len(p)) + n, err = c.Conn.Write(p) + log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Read(p []byte) (n int, err error) { + log.Printf("%s.Read(%d) = ....", c.name, len(p)) + n, err = c.Conn.Read(p) + log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err) + return +} + +func (c *loggingConn) Close() (err error) { + log.Printf("%s.Close() = ...", c.name) + err = c.Conn.Close() + log.Printf("%s.Close() = %v", c.name, err) + return +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +type conn struct { + server *Server + rwc net.Conn + remoteAddr string +} + +func (c *conn) setState(nc net.Conn, state ConnState) { + if hook := c.server.ConnState; hook != nil { + hook(nc, state) + } +} + +// Serve a new connection +func (c *conn) serve(ctx context.Context) { + c.remoteAddr = c.rwc.RemoteAddr().String() + defer func() { + if err := recover(); err != nil { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + c.server.logf("ssh: panic serving %v: %v\n%s", c.remoteAddr, err, buf) + } + c.close() + }() + + // Before use, a handshake must be performed on the incoming net.Conn. + c.setState(c.rwc, StateHandshake) + sConn, chans, reqs, err := ssh.NewServerConn(c.rwc, c.server.ServerConfig) + if err != nil { + // TODO: handle error + if err != io.EOF { + c.server.logf("ssh: Handshake error: %v", err) + } + return + } + c.setState(c.rwc, StateActive) + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) // TODO: Handle these + + for newChannel := range chans { + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + // TODO: Figure out what to do here. Are there errors which are like EOF? + return + } + ctx, cancelCtx := context.WithCancel(ctx) + var permissions *Permissions + if sConn.Permissions != nil { + permissions = &Permissions{*sConn.Permissions} + } + go serverHandler{c.server}.ServeSSH( + permissions, + Channel{ + Channel: channel, + ctx: ctx, + cancelCtx: cancelCtx, + }, + wrapRequests(requests), + ) + } +} + +func (c *conn) close() { + c.rwc.Close() + c.setState(c.rwc, StateClosed) +} + +// serverHandler delegates to either the server's Handler or DefaultHandler, +// and also cancels the context when finished. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeSSH(p *Permissions, c Channel, reqs <-chan *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultHandler + } + handler.ServeSSH(p, c, reqs) + c.cancelCtx() +} diff --git a/state.go b/state.go new file mode 100644 index 0000000..7ef0bd2 --- /dev/null +++ b/state.go @@ -0,0 +1,38 @@ +package ssh + +// A ConnState represents the state of a client connection to a server. +// It's used by the optional Server.ConnState hook. +type ConnState int + +const ( + // StateNew represents a new connection that has newly connected. Connections + // begin at this state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateHandshake represents a connection that is currently performing the + // handshake. + StateHandshake + + // StateActive represents a connection that has read 1 or more + // bytes of a request. + StateActive + + // StateClosed represents a closed connection. + // This is a terminal state. + StateClosed +) + +func (s ConnState) String() string { + switch s { + case StateNew: + return "new" + case StateHandshake: + return "handshake" + case StateActive: + return "active" + case StateClosed: + return "closed" + } + return "unknown" +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..c5dca0f --- /dev/null +++ b/types.go @@ -0,0 +1,41 @@ +package ssh + +import ( + "context" + + "golang.org/x/crypto/ssh" +) + +type Permissions struct { + ssh.Permissions +} + +type Channel struct { + ssh.Channel + ctx context.Context + cancelCtx context.CancelFunc +} + +type Request struct { + ssh.Request +} + +func wrapRequests(requests <-chan *ssh.Request) <-chan *Request { + wrapped := make(chan *Request) + go func() { + for r := range requests { + wrapped <- &Request{*r} + } + }() + return wrapped +} + +func unwrapRequests(requests <-chan *Request) <-chan *ssh.Request { + unwrapped := make(chan *ssh.Request) + go func() { + for r := range requests { + unwrapped <- &r.Request + } + }() + return unwrapped +}