diff --git a/cmd/guac/main.go b/cmd/guac/main.go index 12ba82c..bb6af00 100644 --- a/cmd/guac/main.go +++ b/cmd/guac/main.go @@ -14,7 +14,7 @@ func main() { fs := http.FileServer(http.Dir(".")) - servlet := guac.NewHTTPTunnelServlet(DemoDoConnect) + servlet := guac.NewServer(DemoDoConnect) mux := http.NewServeMux() mux.Handle("/tunnel", servlet) diff --git a/cmd/wwt/server.go b/cmd/wwt/server.go index 6ad594e..70c4b58 100644 --- a/cmd/wwt/server.go +++ b/cmd/wwt/server.go @@ -15,7 +15,7 @@ func main() { fs := http.FileServer(http.Dir(".")) - servlet := guac.NewHTTPTunnelServlet(DemoDoConnect) + servlet := guac.NewServer(DemoDoConnect) wsServer := guac.NewWebsocketServer(DemoDoConnect) mux := http.NewServeMux() @@ -104,7 +104,7 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { info.AudioMimetypes = []string{"audio/L16", "rate=44100", "channels=2"} logrus.Debug("Connecting to guacd") - socket, err := guac.NewInetSocket("127.0.0.1", 4822) + stream, err := guac.NewInetSocket("127.0.0.1", 4822) if err != nil { return nil, err } @@ -112,10 +112,10 @@ func DemoDoConnect(request *http.Request) (guac.Tunnel, error) { if request.URL.Query().Get("uuid") != "" { config.ConnectionID = request.URL.Query().Get("uuid") } - err = guac.ConfigureSocket(socket, config, info) + err = guac.ConfigureSocket(stream, config, info) if err != nil { return nil, err } logrus.Debug("Socket configured") - return guac.NewSimpleTunnel(socket), nil + return guac.NewSimpleTunnel(stream), nil } diff --git a/reentrantLock.go b/counted_lock.go similarity index 100% rename from reentrantLock.go rename to counted_lock.go diff --git a/httpTunnelServer.go b/server.go similarity index 83% rename from httpTunnelServer.go rename to server.go index f3619e1..443f37a 100644 --- a/httpTunnelServer.go +++ b/server.go @@ -16,16 +16,16 @@ const ( uuidLength = 36 ) -// HttpTunnelServer uses HTTP requests to talk to guacd -type HttpTunnelServer struct { - tunnels HttpTunnelMap +// Server uses HTTP requests to talk to guacd +type Server struct { + tunnels TunnelMap connect func(*http.Request) (Tunnel, error) } -// NewHTTPTunnelServlet constructor -func NewHTTPTunnelServlet(connect func(r *http.Request) (Tunnel, error)) *HttpTunnelServer { - return &HttpTunnelServer{ - tunnels: NewHttpTunnelMap(), +// NewServer constructor +func NewServer(connect func(r *http.Request) (Tunnel, error)) *Server { + return &Server{ + tunnels: NewTunnelMap(), connect: connect, } } @@ -34,7 +34,7 @@ func NewHTTPTunnelServlet(connect func(r *http.Request) (Tunnel, error)) *HttpTu * Registers the given tunnel such that future read/write requests to that * tunnel will be properly directed. */ -func (s *HttpTunnelServer) registerTunnel(tunnel Tunnel) { +func (s *Server) registerTunnel(tunnel Tunnel) { s.tunnels.Put(tunnel.GetUUID(), tunnel) logger.Debugf("Registered tunnel \"%v\".", tunnel.GetUUID()) } @@ -43,7 +43,7 @@ func (s *HttpTunnelServer) registerTunnel(tunnel Tunnel) { * Deregisters the given tunnel such that future read/write requests to * that tunnel will be rejected. */ -func (s *HttpTunnelServer) deregisterTunnel(tunnel Tunnel) { +func (s *Server) deregisterTunnel(tunnel Tunnel) { s.tunnels.Remove(tunnel.GetUUID()) logger.Debugf("Deregistered tunnel \"%v\".", tunnel.GetUUID()) } @@ -52,7 +52,7 @@ func (s *HttpTunnelServer) deregisterTunnel(tunnel Tunnel) { * Returns the tunnel with the given UUID, if it has been registered with * registerTunnel() and not yet deregistered with deregisterTunnel(). */ -func (s *HttpTunnelServer) getTunnel(tunnelUUID string) (ret Tunnel, err error) { +func (s *Server) getTunnel(tunnelUUID string) (ret Tunnel, err error) { var ok bool ret, ok = s.tunnels.Get(tunnelUUID) @@ -62,13 +62,13 @@ func (s *HttpTunnelServer) getTunnel(tunnelUUID string) (ret Tunnel, err error) return } -func (s *HttpTunnelServer) sendError(response http.ResponseWriter, guacStatus Status, message string) { +func (s *Server) sendError(response http.ResponseWriter, guacStatus Status, message string) { response.Header().Set("Guacamole-Status-Code", fmt.Sprintf("%v", guacStatus.GetGuacamoleStatusCode())) response.Header().Set("Guacamole-Error-Message", message) response.WriteHeader(guacStatus.GetHTTPStatusCode()) } -func (s *HttpTunnelServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { err := s.handleTunnelRequestCore(w, r) if err == nil { return @@ -86,7 +86,7 @@ func (s *HttpTunnelServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } -func (s *HttpTunnelServer) handleTunnelRequestCore(response http.ResponseWriter, request *http.Request) (err error) { +func (s *Server) handleTunnelRequestCore(response http.ResponseWriter, request *http.Request) (err error) { query := request.URL.RawQuery if len(query) == 0 { return ErrClient.NewError("No query string provided.") @@ -133,7 +133,7 @@ func (s *HttpTunnelServer) handleTunnelRequestCore(response http.ResponseWriter, } // doRead takes guacd messages and sends them in the response -func (s *HttpTunnelServer) doRead(response http.ResponseWriter, request *http.Request, tunnelUUID string) error { +func (s *Server) doRead(response http.ResponseWriter, request *http.Request, tunnelUUID string) error { tunnel, err := s.getTunnel(tunnelUUID) if err != nil { return err @@ -181,7 +181,7 @@ func (s *HttpTunnelServer) doRead(response http.ResponseWriter, request *http.Re } // writeSome drains the guacd buffer holding instructions into the response -func (s *HttpTunnelServer) writeSome(response http.ResponseWriter, guacd InstructionReader, tunnel Tunnel) (err error) { +func (s *Server) writeSome(response http.ResponseWriter, guacd InstructionReader, tunnel Tunnel) (err error) { var message []byte for { @@ -225,7 +225,7 @@ func (s *HttpTunnelServer) writeSome(response http.ResponseWriter, guacd Instruc } // doWrite takes data from the request and sends it to guacd -func (s *HttpTunnelServer) doWrite(response http.ResponseWriter, request *http.Request, tunnelUUID string) error { +func (s *Server) doWrite(response http.ResponseWriter, request *http.Request, tunnelUUID string) error { tunnel, err := s.getTunnel(tunnelUUID) if err != nil { return err diff --git a/socket.go b/socket.go index ab126e9..51cc1dd 100644 --- a/socket.go +++ b/socket.go @@ -7,9 +7,7 @@ import ( "time" ) -// SocketTimeout stream timeout setting -// * The number of milliseconds to wait for data on the TCP stream before -// * timing out. +// SocketTimeout is the longest time a read or write from Guacamole may take const SocketTimeout = 15 * time.Second // NewInetSocket connects to Guacamole via non-tls dialer diff --git a/stream.go b/stream.go index 59eae9b..f454519 100644 --- a/stream.go +++ b/stream.go @@ -128,7 +128,7 @@ func (s *Stream) ReadSome() (instruction []byte, err error) { } n, err = s.conn.Read(s.buffer[len(s.buffer):cap(s.buffer)]) - if err != nil { + if err != nil && n == 0{ switch err.(type) { case net.Error: ex := err.(net.Error) diff --git a/stream_test.go b/stream_test.go index f426d4e..713f511 100644 --- a/stream_test.go +++ b/stream_test.go @@ -22,9 +22,13 @@ func TestInstructionReader_ReadSome(t *testing.T) { if !bytes.Equal(ins, []byte("4.copy,2.ab;")) { t.Error("Unexpected bytes returned") } + if !stream.Available() { + t.Error("Stream has more available but returned false") + } - // Read some more to simulate data being fragmented - copy(conn.ToRead, ",2.ab;") + // Read the rest of the fragmented instruction + n := copy(conn.ToRead, ",2.ab;") + conn.ToRead = conn.ToRead[:n] conn.HasRead = false ins, err = stream.ReadSome() @@ -34,6 +38,9 @@ func TestInstructionReader_ReadSome(t *testing.T) { if !bytes.Equal(ins, []byte("4.copy,2.ab;")) { t.Error("Unexpected bytes returned") } + if stream.Available() { + t.Error("Stream thinks it has more available but doesn't") + } } func TestInstructionReader_Flush(t *testing.T) { @@ -50,6 +57,9 @@ func TestInstructionReader_Flush(t *testing.T) { if s.buffer[0] != '3' && s.buffer[1] != '4' { t.Error("Unexpected buffer contents:", string(s.buffer[:2])) } + if len(s.buffer) != 2 { + t.Error("Unexpected length", len(s.buffer)) + } } type fakeConn struct { diff --git a/httpTunnel.go b/tunnel_map.go similarity index 63% rename from httpTunnel.go rename to tunnel_map.go index 76264db..7b56ab3 100644 --- a/httpTunnel.go +++ b/tunnel_map.go @@ -6,7 +6,7 @@ import ( "time" ) -/*HttpTunnel +/*LastAccessedTunnel * Tracks the last time a particular Tunnel was accessed. This * information is not necessary for tunnels associated with WebSocket * connections, as each WebSocket connection has its own read thread which @@ -16,23 +16,23 @@ import ( * multiple requests, tracking of activity on the tunnel must be performed * independently of the HTTP requests. */ -type HttpTunnel struct { +type LastAccessedTunnel struct { Tunnel lastAccessedTime time.Time } -func NewHttpTunnel(tunnel Tunnel) (ret HttpTunnel) { +func NewHttpTunnel(tunnel Tunnel) (ret LastAccessedTunnel) { ret.Tunnel = tunnel ret.Access() return } -func (opt *HttpTunnel) Access() { - opt.lastAccessedTime = time.Now() +func (t *LastAccessedTunnel) Access() { + t.lastAccessedTime = time.Now() } -func (opt *HttpTunnel) GetLastAccessedTime() time.Time { - return opt.lastAccessedTime +func (t *LastAccessedTunnel) GetLastAccessedTime() time.Time { + return t.lastAccessedTime } /*TunnelTimeout * @@ -43,13 +43,13 @@ func (opt *HttpTunnel) GetLastAccessedTime() time.Time { */ const TunnelTimeout = 15 * time.Second -/*HttpTunnelMap * +/*TunnelMap * * Map-style object which tracks in-use HTTP tunnels, automatically removing * and closing tunnels which have not been used recently. This class is - * intended for use only within the HttpTunnelServer implementation, + * intended for use only within the Server implementation, * and has no real utility outside that implementation. */ -type HttpTunnelMap struct { +type TunnelMap struct { /** * Executor service which runs the periodic tunnel timeout task. */ @@ -64,18 +64,18 @@ type HttpTunnelMap struct { /** * Map of all tunnels that are using HTTP, indexed by tunnel UUID. */ - tunnelMap map[string]*HttpTunnel + tunnelMap map[string]*LastAccessedTunnel tunnelMapLock sync.RWMutex } -/*NewHttpTunnelMap * - * Creates a new HttpTunnelMap which automatically closes and +/*NewTunnelMap * + * Creates a new TunnelMap which automatically closes and * removes HTTP tunnels which are no longer in use. */ -func NewHttpTunnelMap() (ret HttpTunnelMap) { +func NewTunnelMap() (ret TunnelMap) { ret.executor = make([]*time.Ticker, 0, 1) - ret.tunnelMap = make(map[string]*HttpTunnel) + ret.tunnelMap = make(map[string]*LastAccessedTunnel) ret.tunnelTimeout = TunnelTimeout @@ -83,49 +83,49 @@ func NewHttpTunnelMap() (ret HttpTunnelMap) { return } -func (opt *HttpTunnelMap) startScheduled(count int32, timeout time.Duration) { - for i := int32(len(opt.executor)); i < count; i++ { +func (m *TunnelMap) startScheduled(count int32, timeout time.Duration) { + for i := int32(len(m.executor)); i < count; i++ { tick := time.NewTicker(timeout) - go opt.tunnelTimeoutTask(tick.C) + go m.tunnelTimeoutTask(tick.C) - opt.executor = append(opt.executor, tick) + m.executor = append(m.executor, tick) } } -func (opt *HttpTunnelMap) tunnelTimeoutTask(c <-chan time.Time) { +func (m *TunnelMap) tunnelTimeoutTask(c <-chan time.Time) { for { _, ok := <-c if !ok { break } - opt.tunnelTimeoutTaskRun() + m.tunnelTimeoutTaskRun() } } -func (opt *HttpTunnelMap) tunnelTimeoutTaskRun() { +func (m *TunnelMap) tunnelTimeoutTaskRun() { // timeLine = Now() - tunnelTimeout - timeLine := time.Now().Add(0 - opt.tunnelTimeout) + timeLine := time.Now().Add(0 - m.tunnelTimeout) type pair struct { uuid string - tunnel *HttpTunnel + tunnel *LastAccessedTunnel } removeIDs := make([]pair, 0, 1) - opt.tunnelMapLock.RLock() - for uuid, tunnel := range opt.tunnelMap { + m.tunnelMapLock.RLock() + for uuid, tunnel := range m.tunnelMap { if tunnel.GetLastAccessedTime().Before(timeLine) { removeIDs = append(removeIDs, pair{uuid: uuid, tunnel: tunnel}) } } - opt.tunnelMapLock.RUnlock() + m.tunnelMapLock.RUnlock() for _, double := range removeIDs { logrus.Debugf("HTTP tunnel \"%v\" has timed out.", double.uuid) - opt.tunnelMapLock.Lock() - delete(opt.tunnelMap, double.uuid) - opt.tunnelMapLock.Unlock() + m.tunnelMapLock.Lock() + delete(m.tunnelMap, double.uuid) + m.tunnelMapLock.Unlock() if double.tunnel != nil { err := double.tunnel.Close() @@ -139,7 +139,7 @@ func (opt *HttpTunnelMap) tunnelTimeoutTaskRun() { /*Get * * Returns the Tunnel having the given UUID, wrapped within a - * HttpTunnel. If the no tunnel having the given UUID is + * LastAccessedTunnel. If the no tunnel having the given UUID is * available, null is returned. * * @param uuid @@ -147,15 +147,15 @@ func (opt *HttpTunnelMap) tunnelTimeoutTaskRun() { * * @return * The Tunnel having the given UUID, wrapped within a - * HttpTunnel, if such a tunnel exists, or null if there is no + * LastAccessedTunnel, if such a tunnel exists, or null if there is no * such tunnel. */ -func (opt *HttpTunnelMap) Get(uuid string) (tunnel *HttpTunnel, ok bool) { +func (m *TunnelMap) Get(uuid string) (tunnel *LastAccessedTunnel, ok bool) { // Update the last access time - opt.tunnelMapLock.RLock() - tunnel, ok = opt.tunnelMap[uuid] - opt.tunnelMapLock.RUnlock() + m.tunnelMapLock.RLock() + tunnel, ok = m.tunnelMap[uuid] + m.tunnelMapLock.RUnlock() if ok && tunnel != nil { tunnel.Access() @@ -179,36 +179,36 @@ func (opt *HttpTunnelMap) Get(uuid string) (tunnel *HttpTunnel, ok bool) { * The Tunnel being registered, its associated connection * having just been established via HTTP. */ -func (opt *HttpTunnelMap) Put(uuid string, tunnel Tunnel) { +func (m *TunnelMap) Put(uuid string, tunnel Tunnel) { one := NewHttpTunnel(tunnel) - opt.tunnelMapLock.Lock() - opt.tunnelMap[uuid] = &one - opt.tunnelMapLock.Unlock() + m.tunnelMapLock.Lock() + m.tunnelMap[uuid] = &one + m.tunnelMapLock.Unlock() } /*Remove * * Removes the Tunnel having the given UUID, if such a tunnel * exists. The original tunnel is returned wrapped within a - * HttpTunnel. + * LastAccessedTunnel. * * @param uuid * The UUID of the tunnel to remove (deregister). * * @return * The Tunnel having the given UUID, if such a tunnel exists, - * wrapped within a HttpTunnel, or null if no such tunnel + * wrapped within a LastAccessedTunnel, or null if no such tunnel * exists and no removal was performed. */ -func (opt *HttpTunnelMap) Remove(uuid string) (*HttpTunnel, bool) { +func (m *TunnelMap) Remove(uuid string) (*LastAccessedTunnel, bool) { - opt.tunnelMapLock.RLock() - v, ok := opt.tunnelMap[uuid] - opt.tunnelMapLock.RUnlock() + m.tunnelMapLock.RLock() + v, ok := m.tunnelMap[uuid] + m.tunnelMapLock.RUnlock() if ok { - opt.tunnelMapLock.Lock() - delete(opt.tunnelMap, uuid) - opt.tunnelMapLock.Unlock() + m.tunnelMapLock.Lock() + delete(m.tunnelMap, uuid) + m.tunnelMapLock.Unlock() } return v, ok } @@ -217,9 +217,9 @@ func (opt *HttpTunnelMap) Remove(uuid string) (*HttpTunnel, bool) { * Shuts down this tunnel map, disallowing future tunnels from being * registered and reclaiming any resources. */ -func (opt *HttpTunnelMap) Shutdown() { - for _, c := range opt.executor { +func (m *TunnelMap) Shutdown() { + for _, c := range m.executor { c.Stop() } - opt.executor = make([]*time.Ticker, 0, 1) + m.executor = make([]*time.Ticker, 0, 1) } diff --git a/wsServer.go b/ws_server.go similarity index 100% rename from wsServer.go rename to ws_server.go diff --git a/wsServer_test.go b/ws_server_test.go similarity index 100% rename from wsServer_test.go rename to ws_server_test.go