From 73f3aceca6fa573178e654b73f378a0d60f23465 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Thu, 9 Aug 2018 14:56:49 -0400 Subject: [PATCH] fixed a few reflector issues, added some tests --- Gopkg.lock | 9 ++ Gopkg.toml | 4 + cmd/reflector.go | 6 +- dht/dht_announce.go | 1 + reflector/client.go | 15 ++- reflector/client_test.go | 70 ------------- reflector/server.go | 208 +++++++++++++++++++++++++++++---------- reflector/server_test.go | 158 +++++++++++++++++++++++++++++ reflector/shared.go | 58 ----------- 9 files changed, 340 insertions(+), 189 deletions(-) delete mode 100644 reflector/client_test.go create mode 100644 reflector/server_test.go delete mode 100644 reflector/shared.go diff --git a/Gopkg.lock b/Gopkg.lock index 704b144..6b1c36b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -255,6 +255,14 @@ revision = "0db1d5eae1116bf7c8ed96c6749acfbf4daaec3e" version = "v0.3.0" +[[projects]] + branch = "master" + digest = "1:d38c630298ac75e214f3caa5c240ea2923c7a089824d175ba4107d0650d56579" + name = "github.com/phayes/freeport" + packages = ["."] + pruneopts = "" + revision = "e27662a4a9d6b2083dfd7e7b5d0e30985daca925" + [[projects]] branch = "master" digest = "1:6ee36f2cea425916d81fdaaf983469fc18f91b3cf090cfe90fa0a9d85b8bfab7" @@ -398,6 +406,7 @@ "github.com/lbryio/lbry.go/stop", "github.com/lbryio/lbry.go/util", "github.com/lyoshenka/bencode", + "github.com/phayes/freeport", "github.com/sebdah/goldie", "github.com/sirupsen/logrus", "github.com/spf13/cast", diff --git a/Gopkg.toml b/Gopkg.toml index 72cdfd2..d0f1a96 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -29,3 +29,7 @@ [[constraint]] branch = "master" name = "github.com/uber-go/atomic" + +[[constraint]] + branch = "master" + name = "github.com/phayes/freeport" diff --git a/cmd/reflector.go b/cmd/reflector.go index 1b50fb2..75e1326 100644 --- a/cmd/reflector.go +++ b/cmd/reflector.go @@ -9,8 +9,8 @@ import ( "github.com/lbryio/reflector.go/db" "github.com/lbryio/reflector.go/reflector" "github.com/lbryio/reflector.go/store" - log "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -26,7 +26,9 @@ func init() { func reflectorCmd(cmd *cobra.Command, args []string) { db := new(db.SQL) err := db.Connect(globalConfig.DBConn) - checkErr(err) + if err != nil { + log.Fatal(err) + } s3 := store.NewS3BlobStore(globalConfig.AwsID, globalConfig.AwsSecret, globalConfig.BucketRegion, globalConfig.BucketName) combo := store.NewDBBackedS3Store(s3, db) diff --git a/dht/dht_announce.go b/dht/dht_announce.go index 19ecc75..645c180 100644 --- a/dht/dht_announce.go +++ b/dht/dht_announce.go @@ -9,6 +9,7 @@ import ( "github.com/lbryio/lbry.go/errors" "github.com/lbryio/reflector.go/dht/bits" + "golang.org/x/time/rate" ) diff --git a/reflector/client.go b/reflector/client.go index b26d51d..6ea6e55 100644 --- a/reflector/client.go +++ b/reflector/client.go @@ -9,6 +9,9 @@ import ( log "github.com/sirupsen/logrus" ) +// ErrBlobExists is a default error for when a blob already exists on the reflector server. +var ErrBlobExists = errors.Base("blob exists on server") + // Client is an instance of a client connected to a server. type Client struct { conn net.Conn @@ -18,7 +21,7 @@ type Client struct { // Connect connects to a specific clients and errors if it cannot be contacted. func (c *Client) Connect(address string) error { var err error - c.conn, err = net.Dial("tcp", address) + c.conn, err = net.Dial(network, address) if err != nil { return err } @@ -38,8 +41,10 @@ func (c *Client) SendBlob(blob []byte) error { return errors.Err("not connected") } - if len(blob) != maxBlobSize { - return errors.Err("blob must be exactly " + strconv.Itoa(maxBlobSize) + " bytes") + if len(blob) > maxBlobSize { + return errors.Err("blob must be at most " + strconv.Itoa(maxBlobSize) + " bytes") + } else if len(blob) == 0 { + return errors.Err("blob is empty") } blobHash := getBlobHash(blob) @@ -50,6 +55,7 @@ func (c *Client) SendBlob(blob []byte) error { if err != nil { return err } + _, err = c.conn.Write(sendRequest) if err != nil { return err @@ -102,8 +108,7 @@ func (c *Client) doHandshake(version int) error { } var resp handshakeRequestResponse - dec := json.NewDecoder(c.conn) - err = dec.Decode(&resp) + err = json.NewDecoder(c.conn).Decode(&resp) if err != nil { return err } else if resp.Version != version { diff --git a/reflector/client_test.go b/reflector/client_test.go deleted file mode 100644 index 63c01d2..0000000 --- a/reflector/client_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package reflector - -import ( - "crypto/rand" - "io/ioutil" - "os" - "strconv" - "testing" - - "github.com/lbryio/reflector.go/store" - log "github.com/sirupsen/logrus" -) - -var address = "localhost:" + strconv.Itoa(DefaultPort) - -func TestMain(m *testing.M) { - dir, err := ioutil.TempDir("", "reflector_client_test") - if err != nil { - log.Panic("could not create temp directory - ", err) - } - defer func(directory string) { - err := os.RemoveAll(dir) - if err != nil { - log.Panic("error removing files and directory - ", err) - } - }(dir) - - ms := store.MemoryBlobStore{} - s := NewServer(&ms) - go func() { - err := s.Start(address) - if err != nil { - log.Panic("error starting up reflector server - ", err) - } - }() - - os.Exit(m.Run()) -} - -func TestNotConnected(t *testing.T) { - c := Client{} - err := c.SendBlob([]byte{}) - if err == nil { - t.Error("client should error if it is not connected") - } -} - -func TestSmallBlob(t *testing.T) { - c := Client{} - err := c.Connect(address) - if err != nil { - t.Error("error connecting client to server - ", err) - } - - err = c.SendBlob([]byte{}) - if err == nil { - t.Error("client should error if blob is empty") - } - - blob := make([]byte, 1000) - _, err = rand.Read(blob) - if err != nil { - t.Error("failed to make random blob") - } - - err = c.SendBlob([]byte{}) - if err == nil { - t.Error("client should error if blob is the wrong size") - } -} diff --git a/reflector/server.go b/reflector/server.go index cd3a38a..0bc7aa6 100644 --- a/reflector/server.go +++ b/reflector/server.go @@ -2,31 +2,47 @@ package reflector import ( "bufio" + "crypto/sha512" + "encoding/hex" "encoding/json" "io" "net" "strconv" + "time" + + "github.com/lbryio/reflector.go/store" "github.com/lbryio/lbry.go/errors" "github.com/lbryio/lbry.go/stop" - "github.com/lbryio/reflector.go/store" log "github.com/sirupsen/logrus" ) +const ( + // DefaultPort is the port the reflector server listens on if not passed in. + DefaultPort = 5566 + // DefaultTimeout is the default timeout to read or write the next message + DefaultTimeout = 5 * time.Second + + network = "tcp4" + protocolVersion1 = 0 + protocolVersion2 = 1 + maxBlobSize = 2 * 1024 * 1024 +) + // Server is and instance of the reflector server. It houses the blob store and listener. type Server struct { - store store.BlobStore - closed bool - - grp *stop.Group + store store.BlobStore + timeout time.Duration // timeout to read or write next message + grp *stop.Group } // NewServer returns an initialized reflector server pointer. func NewServer(store store.BlobStore) *Server { return &Server{ - store: store, - grp: stop.New(), + store: store, + grp: stop.New(), + timeout: DefaultTimeout, } } @@ -37,16 +53,23 @@ func (s *Server) Shutdown() { log.Debug("reflector server stopped") } -//Start starts the server listener to handle connections. +//Start starts the server to handle connections. func (s *Server) Start(address string) error { - //ToDo - We should make this DRY as it is the same code in both servers. log.Println("reflector listening on " + address) - l, err := net.Listen("tcp4", address) + l, err := net.Listen(network, address) if err != nil { - return err + return errors.Err(err) } - go s.listenForShutdown(l) + s.grp.Add(1) + go func() { + <-s.grp.Ch() + err := l.Close() + if err != nil { + log.Error(errors.Prefix("closing listener", err)) + } + s.grp.Done() + }() s.grp.Add(1) go func() { @@ -57,20 +80,11 @@ func (s *Server) Start(address string) error { return nil } -func (s *Server) listenForShutdown(listener net.Listener) { - <-s.grp.Ch() - s.closed = true - err := listener.Close() - if err != nil { - log.Error("error closing listener for peer server - ", err) - } -} - func (s *Server) listenAndServe(listener net.Listener) { for { conn, err := listener.Accept() if err != nil { - if s.closed { + if s.quitting() { return } log.Error(err) @@ -85,22 +99,32 @@ func (s *Server) listenAndServe(listener net.Listener) { } func (s *Server) handleConn(conn net.Conn) { + // all this stuff is to close the connections correctly when we're shutting down the server + connNeedsClosing := make(chan struct{}) defer func() { - if err := conn.Close(); err != nil { + close(connNeedsClosing) + }() + s.grp.Add(1) + go func() { + defer s.grp.Done() + select { + case <-connNeedsClosing: + case <-s.grp.Ch(): + } + err := conn.Close() + if err != nil { log.Error(errors.Prefix("closing peer conn", err)) } }() - // TODO: connection should time out eventually - err := s.doHandshake(conn) if err != nil { - if err == io.EOF { + if err == io.EOF || s.quitting() { return } err := s.doError(conn, err) if err != nil { - log.Error("error sending error response to reflector client connection - ", err) + log.Error(errors.Prefix("sending handshake error", err)) } return } @@ -108,11 +132,12 @@ func (s *Server) handleConn(conn net.Conn) { for { err = s.receiveBlob(conn) if err != nil { - if err != io.EOF { - err := s.doError(conn, err) - if err != nil { - log.Error("error sending error response for receiving a blob to reflector client connection - ", err) - } + if err == io.EOF || s.quitting() { + return + } + err := s.doError(conn, err) + if err != nil { + log.Error(errors.Prefix("sending blob receive error", err)) } return } @@ -120,7 +145,7 @@ func (s *Server) handleConn(conn net.Conn) { } func (s *Server) doError(conn net.Conn, err error) error { - log.Errorln(err) + log.Errorln(errors.FullTrace(err)) if e2, ok := err.(*json.SyntaxError); ok { log.Printf("syntax error at byte offset %d", e2.Offset) } @@ -128,8 +153,7 @@ func (s *Server) doError(conn net.Conn, err error) error { if err != nil { return err } - _, err = conn.Write(resp) - return err + return s.write(conn, resp) } func (s *Server) receiveBlob(conn net.Conn) error { @@ -165,8 +189,7 @@ func (s *Server) receiveBlob(conn net.Conn) error { return nil } - blob := make([]byte, blobSize) - _, err = io.ReadFull(bufio.NewReader(conn), blob) + blob, err := s.readRawBlob(conn, blobSize) if err != nil { return err } @@ -193,7 +216,7 @@ func (s *Server) receiveBlob(conn net.Conn) error { func (s *Server) doHandshake(conn net.Conn) error { var handshake handshakeRequestResponse - err := json.NewDecoder(conn).Decode(&handshake) + err := s.read(conn, &handshake) if err != nil { return err } else if handshake.Version != protocolVersion1 && handshake.Version != protocolVersion2 { @@ -205,29 +228,20 @@ func (s *Server) doHandshake(conn net.Conn) error { return err } - _, err = conn.Write(resp) - return err + return s.write(conn, resp) } func (s *Server) readBlobRequest(conn net.Conn) (int, string, bool, error) { var sendRequest sendBlobRequest - err := json.NewDecoder(conn).Decode(&sendRequest) + err := s.read(conn, &sendRequest) if err != nil { return 0, "", false, err } - if sendRequest.SdBlobHash != "" && sendRequest.BlobHash != "" { - return 0, "", false, errors.Err("invalid request") - } - var blobHash string var blobSize int isSdBlob := sendRequest.SdBlobHash != "" - if blobSize > maxBlobSize { - return 0, "", isSdBlob, errors.Err("blob cannot be more than " + strconv.Itoa(maxBlobSize) + " bytes") - } - if isSdBlob { blobSize = sendRequest.SdBlobSize blobHash = sendRequest.SdBlobHash @@ -236,6 +250,16 @@ func (s *Server) readBlobRequest(conn net.Conn) (int, string, bool, error) { blobHash = sendRequest.BlobHash } + if blobHash == "" { + return blobSize, blobHash, isSdBlob, errors.Err("blob hash is empty") + } + if blobSize > maxBlobSize { + return blobSize, blobHash, isSdBlob, errors.Err("blob must be at most " + strconv.Itoa(maxBlobSize) + " bytes") + } + if blobSize == 0 { + return blobSize, blobHash, isSdBlob, errors.Err("0-byte blob received") + } + return blobSize, blobHash, isSdBlob, nil } @@ -252,8 +276,7 @@ func (s *Server) sendBlobResponse(conn net.Conn, blobExists, isSdBlob bool) erro return err } - _, err = conn.Write(response) - return err + return s.write(conn, response) } func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob, isSdBlob bool) error { @@ -262,7 +285,6 @@ func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob, isSdBlob bool if isSdBlob { response, err = json.Marshal(sdBlobTransferResponse{ReceivedSdBlob: receivedBlob}) - } else { response, err = json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob}) } @@ -270,6 +292,84 @@ func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob, isSdBlob bool return err } - _, err = conn.Write(response) - return err + return s.write(conn, response) +} + +func (s *Server) read(conn net.Conn, v interface{}) error { + err := conn.SetReadDeadline(time.Now().Add(s.timeout)) + if err != nil { + return errors.Err(err) + } + + return errors.Err(json.NewDecoder(conn).Decode(v)) +} + +func (s *Server) readRawBlob(conn net.Conn, blobSize int) ([]byte, error) { + err := conn.SetReadDeadline(time.Now().Add(s.timeout)) + if err != nil { + return nil, errors.Err(err) + } + + blob := make([]byte, blobSize) + _, err = io.ReadFull(bufio.NewReader(conn), blob) + return blob, errors.Err(err) +} + +func (s *Server) write(conn net.Conn, b []byte) error { + err := conn.SetWriteDeadline(time.Now().Add(s.timeout)) + if err != nil { + return errors.Err(err) + } + + n, err := conn.Write(b) + if err == nil && n != len(b) { + err = io.ErrShortWrite + } + return errors.Err(err) +} + +func (s *Server) quitting() bool { + select { + case <-s.grp.Ch(): + return true + default: + return false + } +} + +func getBlobHash(blob []byte) string { + hashBytes := sha512.Sum384(blob) + return hex.EncodeToString(hashBytes[:]) +} + +type errorResponse struct { + Error string `json:"error"` +} + +type handshakeRequestResponse struct { + Version int `json:"version"` +} + +type sendBlobRequest struct { + BlobHash string `json:"blob_hash,omitempty"` + BlobSize int `json:"blob_size,omitempty"` + SdBlobHash string `json:"sd_blob_hash,omitempty"` + SdBlobSize int `json:"sd_blob_size,omitempty"` +} + +type sendBlobResponse struct { + SendBlob bool `json:"send_blob"` +} + +type sendSdBlobResponse struct { + SendSdBlob bool `json:"send_sd_blob"` + NeededBlobs []string `json:"needed_blobs,omitempty"` +} + +type blobTransferResponse struct { + ReceivedBlob bool `json:"received_blob"` +} + +type sdBlobTransferResponse struct { + ReceivedSdBlob bool `json:"received_sd_blob"` } diff --git a/reflector/server_test.go b/reflector/server_test.go new file mode 100644 index 0000000..4cbe6ab --- /dev/null +++ b/reflector/server_test.go @@ -0,0 +1,158 @@ +package reflector + +import ( + "crypto/rand" + "io" + "strconv" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lbryio/reflector.go/store" + + "github.com/phayes/freeport" +) + +func startServerOnRandomPort(t *testing.T) (*Server, int) { + port, err := freeport.GetFreePort() + if err != nil { + t.Fatal(err) + } + + srv := NewServer(&store.MemoryBlobStore{}) + err = srv.Start("127.0.0.1:" + strconv.Itoa(port)) + if err != nil { + t.Fatal(err) + } + + return srv, port +} + +func TestClient_NotConnected(t *testing.T) { + c := Client{} + err := c.SendBlob([]byte{}) + if err == nil { + t.Error("client should error if it is not connected") + } +} + +func TestClient_EmptyBlob(t *testing.T) { + srv, port := startServerOnRandomPort(t) + defer srv.Shutdown() + + c := Client{} + err := c.Connect(":" + strconv.Itoa(port)) + if err != nil { + t.Fatal("error connecting client to server", err) + } + + err = c.SendBlob([]byte{}) + if err == nil { + t.Error("client should not send empty blob") + } +} + +func TestServer_MediumBlob(t *testing.T) { + srv, port := startServerOnRandomPort(t) + defer srv.Shutdown() + + c := Client{} + err := c.Connect(":" + strconv.Itoa(port)) + if err != nil { + t.Fatal("error connecting client to server", err) + } + + blob := make([]byte, 1000) + _, err = rand.Read(blob) + if err != nil { + t.Fatal("failed to make random blob") + } + + err = c.SendBlob(blob) + if err != nil { + t.Error(err) + } +} + +func TestServer_FullBlob(t *testing.T) { + srv, port := startServerOnRandomPort(t) + defer srv.Shutdown() + + c := Client{} + err := c.Connect(":" + strconv.Itoa(port)) + if err != nil { + t.Fatal("error connecting client to server", err) + } + + blob := make([]byte, maxBlobSize) + _, err = rand.Read(blob) + if err != nil { + t.Fatal("failed to make random blob") + } + + err = c.SendBlob(blob) + if err != nil { + t.Error(err) + } +} + +func TestServer_TooBigBlob(t *testing.T) { + srv, port := startServerOnRandomPort(t) + defer srv.Shutdown() + + c := Client{} + err := c.Connect(":" + strconv.Itoa(port)) + if err != nil { + t.Fatal("error connecting client to server", err) + } + + blob := make([]byte, maxBlobSize+1) + _, err = rand.Read(blob) + if err != nil { + t.Fatal("failed to make random blob") + } + + err = c.SendBlob(blob) + if err == nil { + t.Error("server should reject blob above max size") + } +} + +func TestServer_Timeout(t *testing.T) { + t.Skip("server and client have no way to detect errors right now") + + testTimeout := 50 * time.Millisecond + + port, err := freeport.GetFreePort() + if err != nil { + t.Fatal(err) + } + + srv := NewServer(&store.MemoryBlobStore{}) + srv.timeout = testTimeout + err = srv.Start("127.0.0.1:" + strconv.Itoa(port)) + if err != nil { + t.Fatal(err) + } + defer srv.Shutdown() + + c := Client{} + err = c.Connect(":" + strconv.Itoa(port)) + if err != nil { + t.Fatal("error connecting client to server", err) + } + + time.Sleep(testTimeout * 2) + + blob := make([]byte, 10) + _, err = rand.Read(blob) + if err != nil { + t.Fatal("failed to make random blob") + } + + err = c.SendBlob(blob) + t.Log(spew.Sdump(err)) + if err != io.EOF { + t.Error("server should have timed out by now") + } +} diff --git a/reflector/shared.go b/reflector/shared.go deleted file mode 100644 index 44ff5a6..0000000 --- a/reflector/shared.go +++ /dev/null @@ -1,58 +0,0 @@ -package reflector - -import ( - "crypto/sha512" - "encoding/hex" - - "github.com/lbryio/lbry.go/errors" -) - -const ( - // DefaultPort is the port the reflector server listens on if not passed in. - DefaultPort = 5566 - - maxBlobSize = 2 * 1024 * 1024 - - protocolVersion1 = 0 - protocolVersion2 = 1 -) - -// ErrBlobExists is a default error for when a blob already exists on the reflector server. -var ErrBlobExists = errors.Base("blob exists on server") - -type errorResponse struct { - Error string `json:"error"` -} - -type handshakeRequestResponse struct { - Version int `json:"version"` -} - -type sendBlobRequest struct { - BlobHash string `json:"blob_hash,omitempty"` - BlobSize int `json:"blob_size,omitempty"` - SdBlobHash string `json:"sd_blob_hash,omitempty"` - SdBlobSize int `json:"sd_blob_size,omitempty"` -} - -type sendBlobResponse struct { - SendBlob bool `json:"send_blob"` -} - -type sendSdBlobResponse struct { - SendSdBlob bool `json:"send_sd_blob"` - NeededBlobs []string `json:"needed_blobs,omitempty"` -} - -type blobTransferResponse struct { - ReceivedBlob bool `json:"received_blob"` -} - -type sdBlobTransferResponse struct { - ReceivedSdBlob bool `json:"received_sd_blob"` -} - -func getBlobHash(blob []byte) string { - hashBytes := sha512.Sum384(blob) - return hex.EncodeToString(hashBytes[:]) -}