Skip to content

Commit

Permalink
wsutil: Improved internal code
Browse files Browse the repository at this point in the history
  • Loading branch information
diamondburned committed Aug 20, 2020
1 parent fd818e1 commit 6b4e26e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 70 deletions.
138 changes: 72 additions & 66 deletions utils/wsutil/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ import (
"github.com/pkg/errors"
)

// CopyBufferSize is used for the initial size of the internal WS' buffer.
const CopyBufferSize = 2048

// MaxCapUntilReset determines the maximum capacity before the bytes buffer is
// re-allocated. This constant is 4MB.
const MaxCapUntilReset = 4 * (1 << 20)

// CloseDeadline controls the deadline to wait for sending the Close frame.
var CloseDeadline = time.Second

Expand Down Expand Up @@ -52,8 +57,7 @@ type Conn struct {
events chan Event

// write channels
writes chan []byte
errors chan error
writeMu *sync.Mutex

buf bytes.Buffer
zlib io.ReadCloser // (compress/zlib).reader
Expand All @@ -72,15 +76,23 @@ func NewConn() *Conn {
}

func NewConnWithDriver(driver json.Driver) *Conn {
writeMu := sync.Mutex{}
writeMu.Lock()

writeBuf := bytes.Buffer{}
writeBuf.Grow(CopyBufferSize)

return &Conn{
Driver: driver,
dialer: &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: WSTimeout,
ReadBufferSize: CopyBufferSize,
WriteBufferSize: CopyBufferSize,
EnableCompression: true,
},
// zlib: zlib.NewInflator(),
// buf: make([]byte, CopyBufferSize),
writeMu: &writeMu,
buf: writeBuf,
}
}

Expand All @@ -105,12 +117,11 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
// Set up the closer.
c.closeOnce = &sync.Once{}

c.events = make(chan Event)
c.events = make(chan Event, WSBuffer)
go c.readLoop()

c.writes = make(chan []byte)
c.errors = make(chan error)
go c.writeLoop()
// Unlock the mutex that would otherwise be acquired in NewConn and Close.
c.writeMu.Unlock()

return err
}
Expand All @@ -120,12 +131,6 @@ func (c *Conn) Listen() <-chan Event {
}

func (c *Conn) readLoop() {
// Acquire the read lock throughout the span of the loop. This would still
// allow Send to acquire another RLock, but wouldn't allow Close to
// prematurely exit, as Close acquires a write lock.
// c.mut.RLock()
// defer c.mut.RUnlock()

// Clean up the events channel in the end.
defer close(c.events)

Expand Down Expand Up @@ -157,27 +162,6 @@ func (c *Conn) readLoop() {
}
}

func (c *Conn) writeLoop() {
// Closing c.writes would break the loop immediately.
for b := range c.writes {
c.errors <- c.Conn.WriteMessage(websocket.TextMessage, b)
}

// Quick deadline:
deadline := time.Now().Add(CloseDeadline)

// Make a closure message:
msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")

// Send a close message before closing the connection. We're not error
// checking this because it's not important.
c.Conn.WriteControl(websocket.TextMessage, msg, deadline)

// Safe to close now.
c.errors <- c.Conn.Close()
close(c.errors)
}

func (c *Conn) handle() ([]byte, error) {
// skip message type
t, r, err := c.Conn.NextReader()
Expand Down Expand Up @@ -230,50 +214,45 @@ func (c *Conn) handle() ([]byte, error) {
// return nil, errors.New("Unexpected binary message.")
}

// resetDeadline is used to reset the write deadline after using the context's.
var resetDeadline = time.Time{}

func (c *Conn) Send(ctx context.Context, b []byte) error {
// If websocket is already closed.
if c.writes == nil {
return ErrWebsocketClosed
}
c.writeMu.Lock()
defer c.writeMu.Unlock()

// Send the bytes.
select {
case c.writes <- b:
// continue
case <-ctx.Done():
return ctx.Err()
d, ok := ctx.Deadline()
if ok {
c.Conn.SetWriteDeadline(d)
defer c.Conn.SetWriteDeadline(resetDeadline)
}

// Receive the error.
select {
case err := <-c.errors:
return err
case <-ctx.Done():
return ctx.Err()
}
return c.Conn.WriteMessage(websocket.TextMessage, b)
}

func (c *Conn) Close() (err error) {
// Use a sync.Once to guarantee that other Close() calls block until the
// main call is done. It also prevents future calls.
c.closeOnce.Do(func() {
// Close c.writes. This should trigger the websocket to close itself.
close(c.writes)
// Mark c.writes as empty.
c.writes = nil

// Wait for the write loop to exit by flusing the errors channel.
err = <-c.errors // get close error
for range c.errors { // then flush
}
WSDebug("Conn: Acquiring write lock...")

// Acquire the write lock forever.
c.writeMu.Lock()

WSDebug("Conn: Write lock acquired; closing.")

// Close the WS.
err = c.closeWS()

WSDebug("Conn: Websocket closed; error:", err)
WSDebug("Conn: Flusing events...")

// Flush all events before closing the channel. This will return as soon as
// c.events is closed, or after closed.
for range c.events {
}

// Mark c.events as empty.
c.events = nil
WSDebug("Flushed events.")

// Mark c.Conn as empty.
c.Conn = nil
Expand All @@ -282,18 +261,45 @@ func (c *Conn) Close() (err error) {
return err
}

func (c *Conn) closeWS() (err error) {
// We can't close with a write control here, since it will invalidate the
// old session, breaking resumes.

// // Quick deadline:
// deadline := time.Now().Add(CloseDeadline)

// // Make a closure message:
// msg := websocket.FormatCloseMessage(websocket.CloseGoingAway, "")

// // Send a close message before closing the connection. We're not error
// // checking this because it's not important.
// err = c.Conn.WriteControl(websocket.CloseMessage, msg, deadline)

if err := c.Conn.Close(); err != nil {
return err
}

return
}

// readAll reads bytes into an existing buffer, copy it over, then wipe the old
// buffer.
func readAll(buf *bytes.Buffer, r io.Reader) ([]byte, error) {
defer buf.Reset()

if _, err := buf.ReadFrom(r); err != nil {
return nil, err
}

// Copy the bytes so we could empty the buffer for reuse.
p := buf.Bytes()
cpy := make([]byte, len(p))
copy(cpy, p)
cpy := make([]byte, buf.Len())
copy(cpy, buf.Bytes())

// If the buffer's capacity is over the limit, then re-allocate a new one.
if buf.Cap() > MaxCapUntilReset {
*buf = bytes.Buffer{}
buf.Grow(CopyBufferSize)
}

return cpy, nil
}
5 changes: 1 addition & 4 deletions utils/wsutil/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@ import (
var (
// WSTimeout is the timeout for connecting and writing to the Websocket,
// before Gateway cancels and fails.
WSTimeout = 5 * time.Minute
WSTimeout = 30 * time.Second
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSError is the default error handler
WSError = func(err error) { log.Println("Gateway error:", err) }
// WSExtraReadTimeout is the duration to be added to Hello, as a read
// timeout for the websocket.
WSExtraReadTimeout = time.Second
// WSDebug is used for extra debug logging. This is expected to behave
// similarly to log.Println().
WSDebug = func(v ...interface{}) {}
Expand Down

0 comments on commit 6b4e26e

Please sign in to comment.