diff --git a/CHANGELOG.md b/CHANGELOG.md index 898427a..d7aa034 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,23 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [v0.4.0] - 2022-03-24 (Beta) + +## Changes + +- Changing `Connect` signatures and `Start` signatures for servers, and clients +- Changing the functionality of Server.`Start` so that it blocks and returns an error +- Adding `ServeConn` and `FromConn` functions for severs and clients +- Updating `protoc-gen-frisbee` to comply with the new changes +- Updating the buf.build manifest for `protoc-gen-frisbee` +- Making `baseContext`, `onClosed`, and `preWrite` hooks for the Server private, and creating `Setter` functions that + make it impossible to set those functions to nil + +## Fixes + +- Fixing panics from `ConnectSync` and `ConnectAsync` functions when the connection cannot be established - it now + returns an error properly instead + ## [v0.3.2] - 2022-03-18 (Beta) ## Changes @@ -179,7 +196,8 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). Initial Release of Frisbee -[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.3.2...HEAD +[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.4.0...HEAD +[v0.4.0]: https://github.com/loopholelabs/frisbee/compare/v0.3.2...v0.4.0 [v0.3.2]: https://github.com/loopholelabs/frisbee/compare/v0.3.1...v0.3.2 [v0.3.1]: https://github.com/loopholelabs/frisbee/compare/v0.3.0...v0.3.1 [v0.3.0]: https://github.com/loopholelabs/frisbee/compare/v0.2.4...v0.3.0 diff --git a/async.go b/async.go index 2fbf3dd..665d685 100644 --- a/async.go +++ b/async.go @@ -55,7 +55,7 @@ type Async struct { } // ConnectAsync creates a new TCP connection (using net.Dial) and wraps it in a frisbee connection -func ConnectAsync(addr string, keepAlive time.Duration, logger *zerolog.Logger, TLSConfig *tls.Config, blocking bool) (*Async, error) { +func ConnectAsync(addr string, keepAlive time.Duration, logger *zerolog.Logger, TLSConfig *tls.Config) (*Async, error) { var conn net.Conn var err error @@ -63,19 +63,21 @@ func ConnectAsync(addr string, keepAlive time.Duration, logger *zerolog.Logger, conn, err = tls.Dial("tcp", addr, TLSConfig) } else { conn, err = net.Dial("tcp", addr) - _ = conn.(*net.TCPConn).SetKeepAlive(true) - _ = conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) + if err == nil { + _ = conn.(*net.TCPConn).SetKeepAlive(true) + _ = conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) + } } if err != nil { return nil, err } - return NewAsync(conn, logger, blocking), nil + return NewAsync(conn, logger), nil } // NewAsync takes an existing net.Conn object and wraps it in a frisbee connection -func NewAsync(c net.Conn, logger *zerolog.Logger, blocking bool) (conn *Async) { +func NewAsync(c net.Conn, logger *zerolog.Logger) (conn *Async) { conn = &Async{ conn: c, closed: atomic.NewBool(false), diff --git a/async_test.go b/async_test.go index 79582ad..68a5e37 100644 --- a/async_test.go +++ b/async_test.go @@ -40,8 +40,8 @@ func TestNewAsync(t *testing.T) { reader, writer := net.Pipe() - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) p := packet.Get() p.Metadata.Id = 64 @@ -97,8 +97,8 @@ func TestAsyncLargeWrite(t *testing.T) { reader, writer := net.Pipe() - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) randomData := make([][]byte, testSize) p := packet.Get() @@ -145,8 +145,8 @@ func TestAsyncRawConn(t *testing.T) { reader, writer, err := pair.New() require.NoError(t, err) - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) randomData := make([]byte, packetSize) _, _ = rand.Read(randomData) @@ -204,8 +204,8 @@ func TestAsyncReadClose(t *testing.T) { emptyLogger := zerolog.New(ioutil.Discard) - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) p := packet.Get() p.Metadata.Id = 64 @@ -252,8 +252,8 @@ func TestAsyncReadAvailableClose(t *testing.T) { emptyLogger := zerolog.New(ioutil.Discard) - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) p := packet.Get() p.Metadata.Id = 64 @@ -302,8 +302,8 @@ func TestAsyncWriteClose(t *testing.T) { emptyLogger := zerolog.New(ioutil.Discard) - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) p := packet.Get() p.Metadata.Id = 64 @@ -353,8 +353,8 @@ func TestAsyncTimeout(t *testing.T) { reader, writer, err := pair.New() require.NoError(t, err) - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) p := packet.Get() p.Metadata.Id = 64 @@ -389,7 +389,9 @@ func TestAsyncTimeout(t *testing.T) { err = writerConn.conn.Close() assert.NoError(t, err) + runtime.Gosched() time.Sleep(defaultDeadline * 5) + runtime.Gosched() _, err = readerConn.ReadPacket() assert.ErrorIs(t, err, ConnectionClosed) @@ -414,8 +416,8 @@ func BenchmarkAsyncThroughputPipe(b *testing.B) { reader, writer := net.Pipe() - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) b.Run("32 Bytes", throughputRunner(testSize, 32, readerConn, writerConn)) b.Run("512 Bytes", throughputRunner(testSize, 512, readerConn, writerConn)) @@ -437,8 +439,8 @@ func BenchmarkAsyncThroughputNetwork(b *testing.B) { b.Fatal(err) } - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) b.Run("32 Bytes", throughputRunner(testSize, 32, readerConn, writerConn)) b.Run("512 Bytes", throughputRunner(testSize, 512, readerConn, writerConn)) diff --git a/client.go b/client.go index 2b6a5a2..595e76e 100644 --- a/client.go +++ b/client.go @@ -28,7 +28,6 @@ import ( // Client connects to a frisbee Server and can send and receive frisbee packets type Client struct { - addr string conn *Async handlerTable HandlerTable ctx context.Context @@ -48,12 +47,7 @@ type Client struct { // NewClient returns an uninitialized frisbee Client with the registered ClientRouter. // The ConnectAsync method must then be called to dial the server and initialize the connection. -// -// If poolSize == 0 then no pool will be allocated, and all handlers will be run synchronously for their -// incoming connections. If poolSize == -1 then a pool with unlimited size will be allocated. Otherwise, a pool -// with size `poolSize` will be allocated. -func NewClient(addr string, handlerTable HandlerTable, ctx context.Context, opts ...Option) (*Client, error) { - +func NewClient(handlerTable HandlerTable, ctx context.Context, opts ...Option) (*Client, error) { for i := uint16(0); i < RESERVED9; i++ { if _, ok := handlerTable[i]; ok { return nil, InvalidHandlerTable @@ -71,7 +65,6 @@ func NewClient(addr string, handlerTable HandlerTable, ctx context.Context, opts } return &Client{ - addr: addr, handlerTable: handlerTable, ctx: ctx, options: options, @@ -81,26 +74,43 @@ func NewClient(addr string, handlerTable HandlerTable, ctx context.Context, opts } // Connect actually connects to the given frisbee server, and starts the reactor goroutines -// to receive and handle incoming packets. -func (c *Client) Connect() error { - c.Logger().Debug().Msgf("Connecting to %s", c.addr) +// to receive and handle incoming packets. If this function is called, FromConn should not be called. +func (c *Client) Connect(addr string) error { + c.Logger().Debug().Msgf("Connecting to %s", addr) var frisbeeConn *Async var err error - frisbeeConn, err = ConnectAsync(c.addr, c.options.KeepAlive, c.Logger(), c.options.TLSConfig, true) + frisbeeConn, err = ConnectAsync(addr, c.options.KeepAlive, c.Logger(), c.options.TLSConfig) if err != nil { return err } c.conn = frisbeeConn - c.Logger().Info().Msgf("Connected to %s", c.addr) + c.Logger().Info().Msgf("Connected to %s", addr) + + c.wg.Add(1) + go c.handleConn() + c.Logger().Debug().Msgf("Connection handler started for %s", addr) + + if c.options.Heartbeat > time.Duration(0) { + c.wg.Add(1) + go c.heartbeat() + c.Logger().Debug().Msgf("Heartbeat started for %s", addr) + } + + return nil +} +// FromConn takes a pre-existing connection to a Frisbee server and starts the reactor goroutines +// to receive and handle incoming packets. If this function is called, Connect should not be called. +func (c *Client) FromConn(conn net.Conn) error { + c.conn = NewAsync(conn, c.Logger()) c.wg.Add(1) go c.handleConn() - c.Logger().Debug().Msgf("Connection handler started for %s", c.addr) + c.Logger().Debug().Msgf("Connection handler started for %s", c.conn.RemoteAddr()) if c.options.Heartbeat > time.Duration(0) { c.wg.Add(1) go c.heartbeat() - c.Logger().Debug().Msgf("Heartbeat started for %s", c.addr) + c.Logger().Debug().Msgf("Heartbeat started for %s", c.conn.RemoteAddr()) } return nil @@ -210,7 +220,7 @@ LOOP: c.ctx = c.UpdateContext(c.ctx, c.conn) } case CLOSE: - c.Logger().Debug().Msgf("Closing connection %s because of CLOSE action", c.addr) + c.Logger().Debug().Msgf("Closing connection %s because of CLOSE action", c.conn.RemoteAddr()) c.wg.Done() _ = c.Close() return diff --git a/client_test.go b/client_test.go index edca3a4..cd355ae 100644 --- a/client_test.go +++ b/client_test.go @@ -21,6 +21,7 @@ import ( "crypto/rand" "github.com/loopholelabs/frisbee/pkg/metadata" "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -63,22 +64,24 @@ func TestClientRaw(t *testing.T) { } emptyLogger := zerolog.New(ioutil.Discard) - s, err := NewServer(":0", serverHandlerTable, WithLogger(&emptyLogger)) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) s.ConnContext = func(ctx context.Context, c *Async) context.Context { return context.WithValue(ctx, clientConnContextKey, c) } - err = s.Start() + serverConn, clientConn, err := pair.New() require.NoError(t, err) - c, err := NewClient(s.listener.Addr().String(), clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) assert.NoError(t, err) _, err = c.Raw() assert.ErrorIs(t, ConnectionNotInitialized, err) - err = c.Connect() + err = c.FromConn(clientConn) require.NoError(t, err) data := make([]byte, packetSize) @@ -154,18 +157,20 @@ func TestClientStaleClose(t *testing.T) { } emptyLogger := zerolog.New(ioutil.Discard) - s, err := NewServer(":0", serverHandlerTable, WithLogger(&emptyLogger)) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) - err = s.Start() + serverConn, clientConn, err := pair.New() require.NoError(t, err) - c, err := NewClient(s.listener.Addr().String(), clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) assert.NoError(t, err) _, err = c.Raw() assert.ErrorIs(t, ConnectionNotInitialized, err) - err = c.Connect() + err = c.FromConn(clientConn) require.NoError(t, err) data := make([]byte, packetSize) @@ -210,21 +215,23 @@ func BenchmarkThroughputClient(b *testing.B) { } emptyLogger := zerolog.New(ioutil.Discard) - s, err := NewServer(":0", serverHandlerTable, WithLogger(&emptyLogger)) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } - err = s.Start() + serverConn, clientConn, err := pair.New() if err != nil { b.Fatal(err) } - c, err := NewClient(s.listener.Addr().String(), clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } - err = c.Connect() + err = c.FromConn(clientConn) if err != nil { b.Fatal(err) } @@ -291,21 +298,23 @@ func BenchmarkThroughputResponseClient(b *testing.B) { } emptyLogger := zerolog.New(ioutil.Discard) - s, err := NewServer(":0", serverHandlerTable, WithLogger(&emptyLogger)) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } - err = s.Start() + serverConn, clientConn, err := pair.New() if err != nil { b.Fatal(err) } - c, err := NewClient(s.listener.Addr().String(), clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } - err = c.Connect() + err = c.FromConn(clientConn) if err != nil { b.Fatal(err) } diff --git a/frisbee_test.go b/frisbee_test.go index 35d75dc..a209a0d 100644 --- a/frisbee_test.go +++ b/frisbee_test.go @@ -33,7 +33,7 @@ func ExampleNewClient() { logger := zerolog.New(os.Stdout) - _, _ = frisbee.NewClient("127.0.0.1:8080", handlerTable, context.Background(), frisbee.WithLogger(&logger)) + _, _ = frisbee.NewClient(handlerTable, context.Background(), frisbee.WithLogger(&logger)) } func ExampleNewServer() { @@ -45,5 +45,5 @@ func ExampleNewServer() { logger := zerolog.New(os.Stdout) - _, _ = frisbee.NewServer("127.0.0.1:8080", handlerTable, frisbee.WithLogger(&logger)) + _, _ = frisbee.NewServer(handlerTable, frisbee.WithLogger(&logger)) } diff --git a/protoc-gen-frisbee/dockerfile b/protoc-gen-frisbee/dockerfile index fe1263e..088d96c 100644 --- a/protoc-gen-frisbee/dockerfile +++ b/protoc-gen-frisbee/dockerfile @@ -2,7 +2,7 @@ FROM golang as builder ENV GOOS=linux GOARCH=amd64 CGO_ENABLED=0 -RUN go install github.com/loopholelabs/frisbee/protoc-gen-frisbee@v0.3.1 +RUN go install github.com/loopholelabs/frisbee/protoc-gen-frisbee@v0.4.0 # Note, the Docker images must be built for amd64. If the host machine architecture is not amd64 # you need to cross-compile the binary and move it into /go/bin. @@ -12,7 +12,7 @@ FROM scratch # Runtime dependencies LABEL "build.buf.plugins.runtime_library_versions.0.name"="github.com/loopholelabs/frisbee" -LABEL "build.buf.plugins.runtime_library_versions.0.version"="v0.3.1" +LABEL "build.buf.plugins.runtime_library_versions.0.version"="v0.4.0" COPY --from=builder /go/bin / diff --git a/protoc-gen-frisbee/pkg/generator/client.go b/protoc-gen-frisbee/pkg/generator/client.go index 1869884..2147e0e 100644 --- a/protoc-gen-frisbee/pkg/generator/client.go +++ b/protoc-gen-frisbee/pkg/generator/client.go @@ -131,19 +131,19 @@ func writeClient(f File, services protoreflect.ServiceDescriptors) { f.P(typeClose) f.P() - f.P("func NewClient(addr string, tlsConfig *tls.Config, logger *zerolog.Logger) (*Client, error) {") + f.P("func NewClient(tlsConfig *tls.Config, logger *zerolog.Logger) (*Client, error) {") f.P(tab, "c := new(Client)") f.P(tab, "table := make(frisbee.HandlerTable)") writeClientHandlers(f, services) f.P(tab, "var err error") f.P(tab, "if tlsConfig != nil {") - f.P(tab, tab, "c.Client, err = frisbee.NewClient(addr, table, context.Background(), frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger))") + f.P(tab, tab, "c.Client, err = frisbee.NewClient(table, context.Background(), frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger))") f.P(tab, tab, "if err != nil {") f.P(tab, tab, tab, "return nil, err") f.P(tab, tab, "}") f.P(tab, "} else {") - f.P(tab, tab, "c.Client, err = frisbee.NewClient(addr, table, context.Background(), frisbee.WithLogger(logger))") + f.P(tab, tab, "c.Client, err = frisbee.NewClient(table, context.Background(), frisbee.WithLogger(logger))") f.P(tab, tab, "if err != nil {") f.P(tab, tab, tab, "return nil, err") f.P(tab, tab, "}") diff --git a/protoc-gen-frisbee/pkg/generator/server.go b/protoc-gen-frisbee/pkg/generator/server.go index eaefa0a..9721577 100644 --- a/protoc-gen-frisbee/pkg/generator/server.go +++ b/protoc-gen-frisbee/pkg/generator/server.go @@ -78,19 +78,19 @@ func writeServer(f File, services protoreflect.ServiceDescriptors) { } serverFields := builder.String() serverFields = serverFields[:len(serverFields)-2] - f.P("func NewServer(", serverFields, ", listenAddr string, tlsConfig *tls.Config, logger *zerolog.Logger) (*Server, error) {") + f.P("func NewServer(", serverFields, ", tlsConfig *tls.Config, logger *zerolog.Logger) (*Server, error) {") f.P(tab, "table := make(frisbee.HandlerTable)") writeServerHandlers(f, services) f.P(tab, "var s *frisbee.Server") f.P(tab, "var err error") f.P(tab, "if tlsConfig != nil {") - f.P(tab, tab, "s, err = frisbee.NewServer(listenAddr, table, frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger))") + f.P(tab, tab, "s, err = frisbee.NewServer(table, frisbee.WithTLS(tlsConfig), frisbee.WithLogger(logger))") f.P(tab, tab, "if err != nil {") f.P(tab, tab, tab, "return nil, err") f.P(tab, tab, "}") f.P(tab, "} else {") - f.P(tab, tab, "s, err = frisbee.NewServer(listenAddr, table, frisbee.WithLogger(logger))") + f.P(tab, tab, "s, err = frisbee.NewServer(table, frisbee.WithLogger(logger))") f.P(tab, tab, "if err != nil {") f.P(tab, tab, tab, "return nil, err") f.P(tab, tab, "}") diff --git a/server.go b/server.go index afb1740..ac5c9e5 100644 --- a/server.go +++ b/server.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/pkg/errors" "github.com/rs/zerolog" "go.uber.org/atomic" "net" @@ -27,6 +28,12 @@ import ( "time" ) +var ( + BaseContextNil = errors.New("BaseContext cannot be nil") + OnClosedNil = errors.New("OnClosed cannot be nil") + PreWriteNil = errors.New("PreWrite cannot be nil") +) + var ( defaultBaseContext = func() context.Context { return context.Background() @@ -40,7 +47,6 @@ var ( // Server accepts connections from frisbee Clients and can send and receive frisbee Packets type Server struct { listener net.Listener - addr string handlerTable HandlerTable shutdown *atomic.Bool options *Options @@ -48,8 +54,14 @@ type Server struct { connections map[*Async]struct{} connectionsMu sync.Mutex - // BaseContext is used to define the base context for this Server and all incoming connections - BaseContext func() context.Context + // baseContext is used to define the base context for this Server and all incoming connections + baseContext func() context.Context + + // onClosed is a function run by the server whenever a connection is closed + onClosed func(*Async, error) + + // preWrite is run by the server before a write happens + preWrite func() // ConnContext is used to define a connection-specific context based on the incoming connection // and is run whenever a new connection is opened @@ -62,17 +74,11 @@ type Server struct { // UpdateContext is used to update a handler-specific context whenever the returned // Action from a handler is UPDATE UpdateContext func(context.Context, *Async) context.Context - - // OnClosed is a function run by the server whenever a connection is closed - OnClosed func(*Async, error) - - // PreWrite is run by the server before a write happens - PreWrite func() } // NewServer returns an uninitialized frisbee Server with the registered HandlerTable. // The Start method must then be called to start the server and listen for connections. -func NewServer(addr string, handlerTable HandlerTable, opts ...Option) (*Server, error) { +func NewServer(handlerTable HandlerTable, opts ...Option) (*Server, error) { for i := uint16(0); i < RESERVED9; i++ { if _, ok := handlerTable[i]; ok { return nil, InvalidHandlerTable @@ -88,64 +94,77 @@ func NewServer(addr string, handlerTable HandlerTable, opts ...Option) (*Server, } return &Server{ - addr: addr, handlerTable: handlerTable, options: options, shutdown: atomic.NewBool(false), connections: make(map[*Async]struct{}), + baseContext: defaultBaseContext, + onClosed: defaultOnClosed, + preWrite: defaultPreWrite, }, nil } -// Start will start the frisbee server and its reactor goroutines -// to receive and handle incoming connections. If the BaseContext, ConnContext, -// OnClosed, OnShutdown, or PreWrite functions have not been defined, it will -// use the default functions for these. -func (s *Server) Start() error { - - if s.BaseContext == nil { - s.BaseContext = defaultBaseContext +// SetBaseContext sets the baseContext function for the server. If f is nil, it returns an error. +func (s *Server) SetBaseContext(f func() context.Context) error { + if f != nil { + return BaseContextNil } + s.baseContext = f + return nil +} - if s.OnClosed == nil { - s.OnClosed = defaultOnClosed +// SetOnClosed sets the onClosed function for the server. If f is nil, it returns an error. +func (s *Server) SetOnClosed(f func(*Async, error)) error { + if f != nil { + return OnClosedNil } + s.onClosed = f + return nil +} - if s.PreWrite == nil { - s.PreWrite = defaultPreWrite +// SetPreWrite sets the preWrite function for the server. If f is nil, it returns an error. +func (s *Server) SetPreWrite(f func()) error { + if f != nil { + return PreWriteNil } + s.preWrite = f + return nil +} +// Start will start the frisbee server and its reactor goroutines +// to receive and handle incoming connections. If the baseContext, ConnContext, +// onClosed, OnShutdown, or preWrite functions have not been defined, it will +// use the default functions for these. +func (s *Server) Start(addr string) error { var err error if s.options.TLSConfig != nil { - s.listener, err = tls.Listen("tcp", s.addr, s.options.TLSConfig) + s.listener, err = tls.Listen("tcp", addr, s.options.TLSConfig) } else { - s.listener, err = net.Listen("tcp", s.addr) + s.listener, err = net.Listen("tcp", addr) } if err != nil { return err } - s.wg.Add(1) - go s.handleListener() - - return nil + return s.handleListener() } -func (s *Server) handleListener() { +func (s *Server) handleListener() error { var newConn net.Conn var err error for { newConn, err = s.listener.Accept() if err != nil { if s.shutdown.Load() { - s.wg.Done() - return + return nil } - s.Logger().Fatal().Err(err).Msg("error while accepting connection") - s.wg.Done() - return + return err } - s.wg.Add(1) - go s.handleConn(newConn) + go func() { + s.wg.Add(1) + s.ServeConn(newConn) + s.wg.Done() + }() } } @@ -176,7 +195,7 @@ HANDLE: } outgoing, action = handlerFunc(packetCtx, p) if outgoing != nil && outgoing.Metadata.ContentLength == uint32(len(outgoing.Content.B)) { - s.PreWrite() + s.preWrite() err = frisbeeConn.WritePacket(outgoing) if outgoing != p { packet.Put(outgoing) @@ -203,7 +222,8 @@ HANDLE: goto LOOP } -func (s *Server) handleConn(newConn net.Conn) { +// ServeConn takes a TCP net.Conn and serves it using the Server +func (s *Server) ServeConn(newConn net.Conn) { var err error switch v := newConn.(type) { case *net.TCPConn: @@ -211,24 +231,21 @@ func (s *Server) handleConn(newConn net.Conn) { if err != nil { s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive") _ = v.Close() - s.wg.Done() return } err = v.SetKeepAlivePeriod(s.options.KeepAlive) if err != nil { s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive Period") _ = v.Close() - s.wg.Done() return } } - frisbeeConn := NewAsync(newConn, s.Logger(), true) - connCtx := s.BaseContext() + frisbeeConn := NewAsync(newConn, s.Logger()) + connCtx := s.baseContext() s.connectionsMu.Lock() if s.shutdown.Load() { - s.wg.Done() return } s.connections[frisbeeConn] = struct{}{} @@ -236,13 +253,12 @@ func (s *Server) handleConn(newConn net.Conn) { err = s.handlePacket(frisbeeConn, connCtx) _ = frisbeeConn.Close() - s.OnClosed(frisbeeConn, err) + s.onClosed(frisbeeConn, err) s.connectionsMu.Lock() if !s.shutdown.Load() { delete(s.connections, frisbeeConn) } s.connectionsMu.Unlock() - s.wg.Done() } // Logger returns the server's logger (useful for ServerRouter functions) @@ -260,5 +276,8 @@ func (s *Server) Shutdown() error { } s.connectionsMu.Unlock() defer s.wg.Wait() - return s.listener.Close() + if s.listener != nil { + return s.listener.Close() + } + return nil } diff --git a/server_test.go b/server_test.go index 48d0725..4197bb3 100644 --- a/server_test.go +++ b/server_test.go @@ -21,13 +21,15 @@ import ( "crypto/rand" "github.com/loopholelabs/frisbee/pkg/metadata" "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/loopholelabs/testing/conn/pair" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" "net" + "runtime" + "sync" "testing" - "time" ) // trunk-ignore-all(golangci-lint/staticcheck) @@ -63,23 +65,26 @@ func TestServerRaw(t *testing.T) { } emptyLogger := zerolog.New(ioutil.Discard) - s, err := NewServer(":0", serverHandlerTable, WithLogger(&emptyLogger)) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) s.ConnContext = func(ctx context.Context, c *Async) context.Context { return context.WithValue(ctx, serverConnContextKey, c) } - err = s.Start() + serverConn, clientConn, err := pair.New() require.NoError(t, err) - c, err := NewClient(s.listener.Addr().String(), clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) assert.NoError(t, err) + _, err = c.Raw() assert.ErrorIs(t, ConnectionNotInitialized, err) - err = c.Connect() - require.NoError(t, err) + err = c.FromConn(clientConn) + assert.NoError(t, err) data := make([]byte, packetSize) _, _ = rand.Read(data) @@ -157,18 +162,20 @@ func TestServerStaleClose(t *testing.T) { } emptyLogger := zerolog.New(ioutil.Discard) - s, err := NewServer(":0", serverHandlerTable, WithLogger(&emptyLogger)) + s, err := NewServer(serverHandlerTable, WithLogger(&emptyLogger)) require.NoError(t, err) - err = s.Start() + serverConn, clientConn, err := pair.New() require.NoError(t, err) - c, err := NewClient(s.listener.Addr().String(), clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) + go s.ServeConn(serverConn) + + c, err := NewClient(clientHandlerTable, context.Background(), WithLogger(&emptyLogger)) assert.NoError(t, err) _, err = c.Raw() assert.ErrorIs(t, ConnectionNotInitialized, err) - err = c.Connect() + err = c.FromConn(clientConn) require.NoError(t, err) data := make([]byte, packetSize) @@ -208,20 +215,19 @@ func BenchmarkThroughputServer(b *testing.B) { } emptyLogger := zerolog.New(ioutil.Discard) - server, err := NewServer(":0", handlerTable, WithLogger(&emptyLogger)) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } - err = server.Start() + serverConn, clientConn, err := pair.New() if err != nil { b.Fatal(err) } - frisbeeConn, err := ConnectAsync(server.listener.Addr().String(), time.Minute*3, &emptyLogger, nil, true) - if err != nil { - b.Fatal(err) - } + go server.ServeConn(serverConn) + + frisbeeConn := NewAsync(clientConn, &emptyLogger) data := make([]byte, packetSize) _, _ = rand.Read(data) @@ -263,6 +269,11 @@ func BenchmarkThroughputResponseServer(b *testing.B) { const testSize = 1<<16 - 1 const packetSize = 512 + serverConn, clientConn, err := pair.New() + if err != nil { + b.Fatal(err) + } + handlerTable := make(HandlerTable) handlerTable[metadata.PacketPing] = func(_ context.Context, incoming *packet.Packet) (outgoing *packet.Packet, action Action) { @@ -276,20 +287,14 @@ func BenchmarkThroughputResponseServer(b *testing.B) { } emptyLogger := zerolog.New(ioutil.Discard) - server, err := NewServer(":0", handlerTable, WithLogger(&emptyLogger)) + server, err := NewServer(handlerTable, WithLogger(&emptyLogger)) if err != nil { b.Fatal(err) } - err = server.Start() - if err != nil { - b.Fatal(err) - } + go server.ServeConn(serverConn) - frisbeeConn, err := ConnectAsync(server.listener.Addr().String(), time.Minute*3, &emptyLogger, nil, true) - if err != nil { - b.Fatal(err) - } + frisbeeConn := NewAsync(clientConn, &emptyLogger) data := make([]byte, packetSize) _, _ = rand.Read(data) @@ -318,7 +323,11 @@ func BenchmarkThroughputResponseServer(b *testing.B) { } if readPacket.Metadata.Id != testSize { - b.Fatal("invalid decoded metadata id") + b.Fatal("invalid decoded metadata id", readPacket.Metadata.Id) + } + + if readPacket.Metadata.Operation != metadata.PacketPong { + b.Fatal("invalid decoded operation", readPacket.Metadata.Operation) } packet.Put(readPacket) } @@ -337,3 +346,104 @@ func BenchmarkThroughputResponseServer(b *testing.B) { b.Fatal(err) } } + +func BenchmarkAsyncThroughputNetworkMultiple(b *testing.B) { + const testSize = 100 + + throughputRunner := func(testSize uint32, packetSize uint32, readerConn Conn, writerConn Conn) func(b *testing.B) { + return func(b *testing.B) { + var err error + + randomData := make([]byte, packetSize) + + p := packet.Get() + p.Metadata.Id = 64 + p.Metadata.Operation = 32 + p.Content.Write(randomData) + p.Metadata.ContentLength = packetSize + for i := 0; i < b.N; i++ { + done := make(chan struct{}, 1) + errCh := make(chan error, 1) + go func() { + for i := uint32(0); i < testSize; i++ { + p, err := readerConn.ReadPacket() + if err != nil { + errCh <- err + return + } + packet.Put(p) + } + done <- struct{}{} + }() + for i := uint32(0); i < testSize; i++ { + select { + case err = <-errCh: + b.Fatal(err) + default: + err = writerConn.WritePacket(p) + if err != nil { + b.Fatal(err) + } + } + } + select { + case <-done: + continue + case err = <-errCh: + b.Fatal(err) + } + } + + packet.Put(p) + } + } + + runner := func(numClients int, packetSize uint32) func(b *testing.B) { + return func(b *testing.B) { + var wg sync.WaitGroup + wg.Add(numClients) + b.SetBytes(int64(testSize * packetSize)) + b.ReportAllocs() + for i := 0; i < numClients; i++ { + go func() { + emptyLogger := zerolog.New(ioutil.Discard) + + reader, writer, err := pair.New() + if err != nil { + b.Error(err) + } + + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) + throughputRunner(testSize, packetSize, readerConn, writerConn)(b) + + _ = readerConn.Close() + _ = writerConn.Close() + wg.Done() + }() + } + wg.Wait() + } + } + + b.Run("1 Pair, 32 Bytes", runner(1, 32)) + b.Run("2 Pair, 32 Bytes", runner(2, 32)) + b.Run("5 Pair, 32 Bytes", runner(5, 32)) + b.Run("10 Pair, 32 Bytes", runner(10, 32)) + b.Run("Half CPU Pair, 32 Bytes", runner(runtime.NumCPU()/2, 32)) + b.Run("CPU Pair, 32 Bytes", runner(runtime.NumCPU(), 32)) + + b.Run("1 Pair, 512 Bytes", runner(1, 512)) + b.Run("2 Pair, 512 Bytes", runner(2, 512)) + b.Run("5 Pair, 512 Bytes", runner(5, 512)) + b.Run("10 Pair, 512 Bytes", runner(10, 512)) + b.Run("Half CPU Pair, 512 Bytes", runner(runtime.NumCPU()/2, 512)) + b.Run("CPU Pair, 512 Bytes", runner(runtime.NumCPU(), 512)) + + b.Run("1 Pair, 4096 Bytes", runner(1, 4096)) + b.Run("2 Pair, 4096 Bytes", runner(2, 4096)) + b.Run("5 Pair, 4096 Bytes", runner(5, 4096)) + b.Run("10 Pair, 4096 Bytes", runner(10, 4096)) + b.Run("Half CPU Pair, 4096 Bytes", runner(runtime.NumCPU()/2, 4096)) + b.Run("CPU Pair, 4096 Bytes", runner(runtime.NumCPU(), 4096)) +} diff --git a/sync.go b/sync.go index cc169e5..c96c359 100644 --- a/sync.go +++ b/sync.go @@ -53,8 +53,10 @@ func ConnectSync(addr string, keepAlive time.Duration, logger *zerolog.Logger, T conn, err = tls.Dial("tcp", addr, TLSConfig) } else { conn, err = net.Dial("tcp", addr) - _ = conn.(*net.TCPConn).SetKeepAlive(true) - _ = conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) + if err == nil { + _ = conn.(*net.TCPConn).SetKeepAlive(true) + _ = conn.(*net.TCPConn).SetKeepAlivePeriod(keepAlive) + } } if err != nil { diff --git a/throughput_test.go b/throughput_test.go index 45e575a..647a834 100644 --- a/throughput_test.go +++ b/throughput_test.go @@ -40,8 +40,8 @@ func BenchmarkAsyncThroughputLarge(b *testing.B) { b.Fatal(err) } - readerConn := NewAsync(reader, &emptyLogger, false) - writerConn := NewAsync(writer, &emptyLogger, false) + readerConn := NewAsync(reader, &emptyLogger) + writerConn := NewAsync(writer, &emptyLogger) b.Run("1MB", throughputRunner(testSize, 1<<20, readerConn, writerConn)) b.Run("2MB", throughputRunner(testSize, 1<<21, readerConn, writerConn))