Skip to content

Commit

Permalink
Feat: Add Streaming Support to an Async Connection (#152)
Browse files Browse the repository at this point in the history
* adding streams

* Updating polyglot and adding streaming support to Async connections and to the client

* Adding a test case to guarantee that both sides can open the stream and things will work exactly as expected

* Making stream additions blocking

* Refactoring streaming to use a callback handler for new streams and to properly discard them when required.

* Adding stream callback functions to the client and server

* Bumping polyglot for new release (and to use varints down the line)
  • Loading branch information
ShivanshVij authored Sep 27, 2022
1 parent c2238d0 commit a24fe81
Show file tree
Hide file tree
Showing 9 changed files with 648 additions and 177 deletions.
304 changes: 189 additions & 115 deletions async.go

Large diffs are not rendered by default.

24 changes: 20 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ func NewClient(handlerTable HandlerTable, ctx context.Context, opts ...Option) (

// Connect actually connects to the given frisbee server, and starts the reactor goroutines
// to receive and handle incoming packets. If this function is called, FromConn should not be called.
func (c *Client) Connect(addr string) error {
func (c *Client) Connect(addr string, streamHandler ...NewStreamHandler) error {
c.Logger().Debug().Msgf("Connecting to %s", addr)
var frisbeeConn *Async
var err error
frisbeeConn, err = ConnectAsync(addr, c.options.KeepAlive, c.Logger(), c.options.TLSConfig)
frisbeeConn, err = ConnectAsync(addr, c.options.KeepAlive, c.Logger(), c.options.TLSConfig, streamHandler...)
if err != nil {
return err
}
Expand All @@ -86,8 +86,8 @@ func (c *Client) Connect(addr string) error {

// FromConn takes a pre-existing connection to a Frisbee server and starts the reactor goroutines
// to receive and handle incoming packets. If this function is called, Connect should not be called.
func (c *Client) FromConn(conn net.Conn) error {
c.conn = NewAsync(conn, c.Logger())
func (c *Client) FromConn(conn net.Conn, streamHandler ...NewStreamHandler) error {
c.conn = NewAsync(conn, c.Logger(), streamHandler...)
c.wg.Add(1)
go c.handleConn()
c.Logger().Debug().Msgf("Connection handler started for %s", c.conn.RemoteAddr())
Expand Down Expand Up @@ -146,6 +146,22 @@ func (c *Client) Raw() (net.Conn, error) {
return c.conn.Raw(), nil
}

// Stream returns a new Stream object that can be used to send and receive frisbee packets
func (c *Client) Stream(id uint16) *Stream {
return c.conn.NewStream(id)
}

// SetNewStreamHandler sets the callback handler for new streams.
//
// It's important to note that this handler is called for new streams and if it is
// not set then stream packets will be dropped.
//
// It's also important to note that the handler itself is called in its own goroutine to
// avoid blocking the read lop. This means that the handler must be thread-safe.
func (c *Client) SetNewStreamHandler(handler NewStreamHandler) {
c.conn.SetNewStreamHandler(handler)
}

// Logger returns the client's logger (useful for ClientRouter functions)
func (c *Client) Logger() *zerolog.Logger {
return c.options.Logger
Expand Down
2 changes: 0 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ type Conn interface {
SetWriteDeadline(time.Time) error
WritePacket(*packet.Packet) error
ReadPacket() (*packet.Packet, error)
SetContext(context.Context)
Context() context.Context
Logger() *zerolog.Logger
Error() error
Raw() net.Conn
Expand Down
8 changes: 7 additions & 1 deletion frisbee.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ import (
var (
InvalidContentLength = errors.New("invalid content length")
ConnectionClosed = errors.New("connection closed")
StreamClosed = errors.New("stream closed")
InvalidStreamPacket = errors.New("invalid stream packet")
ConnectionNotInitialized = errors.New("connection not initialized")
InvalidBufferLength = errors.New("invalid buffer length")
InvalidHandlerTable = errors.New("invalid handler table configuration, a reserved value may have been used")
InvalidOperation = errors.New("invalid operation in packet, a reserved value may have been used")
)

// Action is an ENUM used to modify the state of the client or server from a Handler function
Expand Down Expand Up @@ -65,7 +68,10 @@ const (
// PONG is used to respond to a PING packets
PONG

RESERVED2
// STREAM is used to request that a new stream be created by the receiver to
// receive packets with the same packet ID until a packet with a ContentLength of 0 is received
STREAM

RESERVED3
RESERVED4
RESERVED5
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.18

require (
github.com/loopholelabs/common v0.4.4
github.com/loopholelabs/polyglot-go v0.3.0
github.com/loopholelabs/polyglot-go v0.5.0
github.com/loopholelabs/testing v0.2.3
github.com/pkg/errors v0.9.1
github.com/rs/zerolog v1.27.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/loopholelabs/common v0.4.4 h1:Ge+1v1WiLYgR/4pziOQoJAwUqUm1c9j6nQvnkiFFBsk=
github.com/loopholelabs/common v0.4.4/go.mod h1:YKnljczr4jgxkHhhAwIHh3CJXaff89YBd8Vp3pwpG3k=
github.com/loopholelabs/polyglot-go v0.3.0 h1:iOqPw5B3krCMYfgDaPgPoh1A87ACE8lKdbpbERM58pY=
github.com/loopholelabs/polyglot-go v0.3.0/go.mod h1:9/Hr1nFO9Al46806vMP3DB2k8blQ3gazBPaoOsdgo34=
github.com/loopholelabs/polyglot-go v0.5.0 h1:F65/d+65qgAu2F0GcWzP6UVIwd9897bNEgylNMr8FGk=
github.com/loopholelabs/polyglot-go v0.5.0/go.mod h1:Z0QiNv4KRuWjQWpUerMhmkvRh6ks1pYmEH4SGpG0EHQ=
github.com/loopholelabs/testing v0.2.3 h1:4nVuK5ctaE6ua5Z0dYk2l7xTFmcpCYLUeGjRBp8keOA=
github.com/loopholelabs/testing v0.2.3/go.mod h1:gqtGY91soYD1fQoKQt/6kP14OYpS7gcbcIgq5mc9m8Q=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
Expand Down
114 changes: 62 additions & 52 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type Server struct {
startedCh chan struct{}
concurrency uint64
limiter chan struct{}
streamHandler func(*Stream)

// baseContext is used to define the base context for this Server and all incoming connections
baseContext func() context.Context
Expand All @@ -65,6 +66,9 @@ type Server struct {
// preWrite is run by the server before a write happens
preWrite func()

// StreamHandler is used to handle incoming client-initiated streams on the server
StreamHandler func(*Async, *Stream)

// ConnContext is used to define a connection-specific context based on the incoming connection
// and is run whenever a new connection is opened
ConnContext func(context.Context, *Async) context.Context
Expand Down Expand Up @@ -174,6 +178,13 @@ func (s *Server) Start(addr string) error {
}
s.wg.Add(1)
close(s.startedCh)

if s.StreamHandler != nil {
s.streamHandler = func(stream *Stream) {
s.StreamHandler(stream.Conn(), stream)
}
}

return s.handleListener()
}

Expand Down Expand Up @@ -219,6 +230,50 @@ func (s *Server) handleListener() error {
}
}

func (s *Server) createHandler(conn *Async, closed *atomic.Bool, wg *sync.WaitGroup, ctx context.Context, cancel context.CancelFunc) func(*packet.Packet) {
return func(p *packet.Packet) {
handlerFunc := s.handlerTable[p.Metadata.Operation]
if handlerFunc != nil {
packetCtx := ctx
if s.PacketContext != nil {
packetCtx = s.PacketContext(packetCtx, p)
}
outgoing, action := handlerFunc(packetCtx, p)
if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) {
s.preWrite()
err := conn.WritePacket(outgoing)
if outgoing != p {
packet.Put(outgoing)
}
packet.Put(p)
if err != nil {
_ = conn.Close()
if closed.CAS(false, true) {
s.onClosed(conn, err)
}
cancel()
wg.Done()
return
}
} else {
packet.Put(p)
}
switch action {
case NONE:
case CLOSE:
_ = conn.Close()
if closed.CAS(false, true) {
s.onClosed(conn, nil)
}
cancel()
}
} else {
packet.Put(p)
}
wg.Done()
}
}

func (s *Server) handleSinglePacket(frisbeeConn *Async, connCtx context.Context) {
var p *packet.Packet
var outgoing *packet.Packet
Expand Down Expand Up @@ -276,50 +331,6 @@ func (s *Server) handleSinglePacket(frisbeeConn *Async, connCtx context.Context)
}
}

func (s *Server) handler(conn *Async, closed *atomic.Bool, wg *sync.WaitGroup, ctx context.Context, cancel context.CancelFunc) func(*packet.Packet) {
return func(p *packet.Packet) {
handlerFunc := s.handlerTable[p.Metadata.Operation]
if handlerFunc != nil {
packetCtx := ctx
if s.PacketContext != nil {
packetCtx = s.PacketContext(packetCtx, p)
}
outgoing, action := handlerFunc(packetCtx, p)
if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(*outgoing.Content)) {
s.preWrite()
err := conn.WritePacket(outgoing)
if outgoing != p {
packet.Put(outgoing)
}
packet.Put(p)
if err != nil {
_ = conn.Close()
if closed.CAS(false, true) {
s.onClosed(conn, err)
}
cancel()
wg.Done()
return
}
} else {
packet.Put(p)
}
switch action {
case NONE:
case CLOSE:
_ = conn.Close()
if closed.CAS(false, true) {
s.onClosed(conn, nil)
}
cancel()
}
} else {
packet.Put(p)
}
wg.Done()
}
}

func (s *Server) handleUnlimitedPacket(frisbeeConn *Async, connCtx context.Context) {
p, err := frisbeeConn.ReadPacket()
if err != nil {
Expand All @@ -333,10 +344,10 @@ func (s *Server) handleUnlimitedPacket(frisbeeConn *Async, connCtx context.Conte
wg := new(sync.WaitGroup)
closed := atomic.NewBool(false)
connCtx, cancel := context.WithCancel(connCtx)
handler := s.handler(frisbeeConn, closed, wg, connCtx, cancel)
handle := s.createHandler(frisbeeConn, closed, wg, connCtx, cancel)
for {
wg.Add(1)
go handler(p)
go handle(p)
p, err = frisbeeConn.ReadPacket()
if err != nil {
_ = frisbeeConn.Close()
Expand All @@ -362,16 +373,16 @@ func (s *Server) handleLimitedPacket(frisbeeConn *Async, connCtx context.Context
wg := new(sync.WaitGroup)
closed := atomic.NewBool(false)
connCtx, cancel := context.WithCancel(connCtx)
uHandler := s.handler(frisbeeConn, closed, wg, connCtx, cancel)
handler := func(p *packet.Packet) {
uHandler(p)
handler := s.createHandler(frisbeeConn, closed, wg, connCtx, cancel)
handle := func(p *packet.Packet) {
handler(p)
<-s.limiter
}
for {
select {
case s.limiter <- struct{}{}:
wg.Add(1)
go handler(p)
go handle(p)
p, err = frisbeeConn.ReadPacket()
if err != nil {
_ = frisbeeConn.Close()
Expand Down Expand Up @@ -421,9 +432,8 @@ func (s *Server) serveConn(newConn net.Conn) {
}
}

frisbeeConn := NewAsync(newConn, s.Logger())
frisbeeConn := NewAsync(newConn, s.Logger(), s.streamHandler)
connCtx := s.baseContext()

s.connectionsMu.Lock()
if s.shutdown.Load() {
s.wg.Done()
Expand Down
Loading

0 comments on commit a24fe81

Please sign in to comment.