diff --git a/CHANGELOG.md b/CHANGELOG.md index f56cb5a..898427a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [v0.3.2] - 2022-03-18 (Beta) + +### Changes + +- Swapping the lock-free Queue out with a simpler locking queue that has significantly less lock contention in scenarios + when multiple buffers are required. +- Refactoring server and client to spawn fewer goroutines per-connection. + ## [v0.3.1] - 2022-03-12 (Beta) ### Fixes @@ -171,7 +179,9 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). Initial Release of Frisbee -[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.3.0...HEAD +[unreleased]: https://github.com/loopholelabs/frisbee/compare/v0.3.2...HEAD +[v0.3.2]: https://github.com/loopholelabs/frisbee/compare/v0.3.1...v0.3.2 +[v0.3.1]: https://github.com/loopholelabs/frisbee/compare/v0.3.0...v0.3.1 [v0.3.0]: https://github.com/loopholelabs/frisbee/compare/v0.2.4...v0.3.0 [v0.2.4]: https://github.com/loopholelabs/frisbee/compare/v0.2.3...v0.2.4 [v0.2.3]: https://github.com/loopholelabs/frisbee/compare/v0.2.2...v0.2.3 diff --git a/async.go b/async.go index 96e55aa..2fbf3dd 100644 --- a/async.go +++ b/async.go @@ -42,7 +42,7 @@ type Async struct { closed *atomic.Bool writer *bufio.Writer flusher chan struct{} - incoming *queue.Queue + incoming *queue.Circular logger *zerolog.Logger wg sync.WaitGroup error *atomic.Error @@ -80,7 +80,7 @@ func NewAsync(c net.Conn, logger *zerolog.Logger, blocking bool) (conn *Async) { conn: c, closed: atomic.NewBool(false), writer: bufio.NewWriterSize(c, DefaultBufferSize), - incoming: queue.NewCircular(DefaultBufferSize), + incoming: queue.NewCircular(DefaultBufferSize), flusher: make(chan struct{}, 3), logger: logger, error: atomic.NewError(nil), diff --git a/conn.go b/conn.go index abb82dd..1f7c664 100644 --- a/conn.go +++ b/conn.go @@ -22,16 +22,16 @@ import ( 
"github.com/loopholelabs/frisbee/pkg/packet" "github.com/pkg/errors" "github.com/rs/zerolog" + "io/ioutil" "net" - "os" "time" ) // DefaultBufferSize is the size of the default buffer -const DefaultBufferSize = 1 << 19 +const DefaultBufferSize = 1 << 16 var ( - defaultLogger = zerolog.New(os.Stdout) + defaultLogger = zerolog.New(ioutil.Discard) defaultDeadline = time.Second diff --git a/internal/queue/circular.go b/internal/queue/circular.go new file mode 100644 index 0000000..6c04242 --- /dev/null +++ b/internal/queue/circular.go @@ -0,0 +1,170 @@ +/* + Copyright 2022 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package queue + +import ( + "github.com/loopholelabs/frisbee/pkg/packet" + "sync" + "unsafe" +) + +type Circular struct { + _padding0 [8]uint64 //nolint:structcheck,unused + head uint64 + _padding1 [8]uint64 //nolint:structcheck,unused + tail uint64 + _padding2 [8]uint64 //nolint:structcheck,unused + maxSize uint64 + _padding3 [8]uint64 //nolint:structcheck,unused + closed bool + _padding4 [8]uint64 //nolint:structcheck,unused + lock *sync.Mutex + _padding5 [8]uint64 //nolint:structcheck,unused + notEmpty *sync.Cond + _padding6 [8]uint64 //nolint:structcheck,unused + notFull *sync.Cond + _padding7 [8]uint64 //nolint:structcheck,unused + nodes []unsafe.Pointer +} + +func NewCircular(maxSize uint64) *Circular { + q := &Circular{} + q.lock = &sync.Mutex{} + q.notFull = sync.NewCond(q.lock) + q.notEmpty = sync.NewCond(q.lock) + + q.head = 0 + q.tail = 0 + maxSize++ + if maxSize < 2 { + q.maxSize = 2 + } else { + q.maxSize = round(maxSize) + } + + q.nodes = make([]unsafe.Pointer, q.maxSize) + return q +} + +func (q *Circular) IsEmpty() (empty bool) { + q.lock.Lock() + empty = q.isEmpty() + q.lock.Unlock() + return +} + +func (q *Circular) isEmpty() bool { + return q.head == q.tail +} + +func (q *Circular) IsFull() (full bool) { + q.lock.Lock() + full = q.isFull() + q.lock.Unlock() + return +} + +func (q *Circular) isFull() bool { + return q.head == (q.tail+1)%q.maxSize +} + +func (q *Circular) IsClosed() (closed bool) { + q.lock.Lock() + closed = q.isClosed() + q.lock.Unlock() + return +} + +func (q *Circular) isClosed() bool { + return q.closed +} + +func (q *Circular) Length() (size int) { + q.lock.Lock() + size = q.length() + q.lock.Unlock() + return +} + +func (q *Circular) length() int { + if q.tail < q.head { + return int(q.maxSize - q.head + q.tail) + } + return int(q.tail - q.head) +} + +func (q *Circular) Close() { + q.lock.Lock() + q.closed = true + q.notFull.Broadcast() + q.notEmpty.Broadcast() + q.lock.Unlock() +} + +func (q *Circular) Push(p 
*packet.Packet) error { + q.lock.Lock() +LOOP: + if q.isClosed() { + q.lock.Unlock() + return Closed + } + if q.isFull() { + q.notFull.Wait() + goto LOOP + } + + q.nodes[q.tail] = unsafe.Pointer(p) + q.tail = (q.tail + 1) % q.maxSize + q.notEmpty.Signal() + q.lock.Unlock() + return nil +} + +func (q *Circular) Pop() (p *packet.Packet, err error) { + q.lock.Lock() +LOOP: + if q.isClosed() { + q.lock.Unlock() + return nil, Closed + } + if q.isEmpty() { + q.notEmpty.Wait() + goto LOOP + } + + p = (*packet.Packet)(q.nodes[q.head]) + q.head = (q.head + 1) % q.maxSize + q.notFull.Signal() + q.lock.Unlock() + return +} + +func (q *Circular) Drain() (packets []*packet.Packet) { + q.lock.Lock() + if q.isEmpty() { + q.lock.Unlock() + return nil + } + // NOTE(review): q.length() accounts for wrap-around (head > tail); + // a plain head-tail difference over-counts after the ring wraps and + // would drain stale slots past the tail. + packets = make([]*packet.Packet, 0, q.length()) + + for i := 0; i < cap(packets); i++ { + packets = append(packets, (*packet.Packet)(q.nodes[q.head])) + q.head = (q.head + 1) % q.maxSize + } + q.lock.Unlock() + return packets +} diff --git a/internal/queue/circular_test.go b/internal/queue/circular_test.go new file mode 100644 index 0000000..ce51d9b --- /dev/null +++ b/internal/queue/circular_test.go @@ -0,0 +1,319 @@ +/* + Copyright 2022 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package queue + +import ( + "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestCircular(t *testing.T) { + t.Parallel() + + testPacket := func() *packet.Packet { + return packet.Get() + } + testPacket2 := func() *packet.Packet { + p := packet.Get() + p.Content.Write([]byte{1}) + return p + } + + t.Run("success", func(t *testing.T) { + rb := NewCircular(1) + p := testPacket() + err := rb.Push(p) + assert.NoError(t, err) + actual, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p, actual) + }) + t.Run("out of capacity", func(t *testing.T) { + rb := NewCircular(0) + err := rb.Push(testPacket()) + assert.NoError(t, err) + }) + t.Run("out of capacity with non zero capacity, blocking", func(t *testing.T) { + rb := NewCircular(1) + p1 := testPacket() + err := rb.Push(p1) + assert.NoError(t, err) + doneCh := make(chan struct{}, 1) + p2 := testPacket2() + go func() { + err = rb.Push(p2) + assert.NoError(t, err) + doneCh <- struct{}{} + }() + select { + case <-doneCh: + t.Fatal("Circular did not block on full write") + case <-time.After(time.Millisecond * 10): + actual, err := rb.Pop() + require.NoError(t, err) + assert.Equal(t, p1, actual) + select { + case <-doneCh: + actual, err := rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + case <-time.After(time.Millisecond * 10): + t.Fatal("Circular did not unblock on read from full write") + } + } + }) + t.Run("length calculations", func(t *testing.T) { + rb := NewCircular(1) + p1 := testPacket() + + err := rb.Push(p1) + assert.NoError(t, err) + assert.Equal(t, 1, rb.Length()) + assert.Equal(t, uint64(0), rb.head) + assert.Equal(t, uint64(1), rb.tail) + + actual, err := rb.Pop() + require.NoError(t, err) + assert.Equal(t, p1, actual) + assert.Equal(t, 0, rb.Length()) + assert.Equal(t, uint64(1), rb.head) + assert.Equal(t, uint64(1), rb.tail) + + err = rb.Push(p1) + assert.NoError(t, err) 
+ assert.Equal(t, 1, rb.Length()) + assert.Equal(t, uint64(1), rb.head) + assert.Equal(t, uint64(0), rb.tail) + + rb = NewCircular(4) + + err = rb.Push(p1) + assert.NoError(t, err) + assert.Equal(t, 1, rb.Length()) + assert.Equal(t, uint64(0), rb.head) + assert.Equal(t, uint64(1), rb.tail) + + p2 := testPacket2() + err = rb.Push(p2) + assert.NoError(t, err) + assert.Equal(t, 2, rb.Length()) + assert.Equal(t, uint64(0), rb.head) + assert.Equal(t, uint64(2), rb.tail) + + err = rb.Push(p2) + assert.NoError(t, err) + assert.Equal(t, 3, rb.Length()) + assert.Equal(t, uint64(0), rb.head) + assert.Equal(t, uint64(3), rb.tail) + + actual, err = rb.Pop() + require.NoError(t, err) + assert.Equal(t, p1, actual) + assert.Equal(t, 2, rb.Length()) + assert.Equal(t, uint64(1), rb.head) + assert.Equal(t, uint64(3), rb.tail) + + actual, err = rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + assert.Equal(t, 1, rb.Length()) + assert.Equal(t, uint64(2), rb.head) + assert.Equal(t, uint64(3), rb.tail) + + err = rb.Push(p2) + assert.NoError(t, err) + assert.Equal(t, 2, rb.Length()) + assert.Equal(t, uint64(2), rb.head) + assert.Equal(t, uint64(4), rb.tail) + + err = rb.Push(p2) + assert.NoError(t, err) + assert.Equal(t, 3, rb.Length()) + assert.Equal(t, uint64(2), rb.head) + assert.Equal(t, uint64(5), rb.tail) + + actual, err = rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + assert.Equal(t, 2, rb.Length()) + assert.Equal(t, uint64(3), rb.head) + assert.Equal(t, uint64(5), rb.tail) + + actual, err = rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + assert.Equal(t, 1, rb.Length()) + assert.Equal(t, uint64(4), rb.head) + assert.Equal(t, uint64(5), rb.tail) + + actual, err = rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + assert.Equal(t, 0, rb.Length()) + assert.Equal(t, uint64(5), rb.head) + assert.Equal(t, uint64(5), rb.tail) + }) + t.Run("buffer closed", func(t *testing.T) { + rb := NewCircular(1) + assert.False(t, 
rb.IsClosed()) + rb.Close() + assert.True(t, rb.IsClosed()) + err := rb.Push(testPacket()) + assert.ErrorIs(t, Closed, err) + _, err = rb.Pop() + assert.ErrorIs(t, Closed, err) + }) + t.Run("pop empty", func(t *testing.T) { + done := make(chan struct{}, 1) + rb := NewCircular(1) + go func() { + _, _ = rb.Pop() + done <- struct{}{} + }() + assert.Equal(t, 0, len(done)) + _ = rb.Push(testPacket()) + <-done + assert.Equal(t, 0, rb.Length()) + }) + t.Run("partial overflow, blocking", func(t *testing.T) { + rb := NewCircular(4) + p1 := testPacket() + p1.Metadata.Id = 1 + + p2 := testPacket() + p2.Metadata.Id = 2 + + p3 := testPacket() + p3.Metadata.Id = 3 + + p4 := testPacket() + p4.Metadata.Id = 4 + + p5 := testPacket() + p5.Metadata.Id = 5 + + err := rb.Push(p1) + assert.NoError(t, err) + err = rb.Push(p2) + assert.NoError(t, err) + err = rb.Push(p3) + assert.NoError(t, err) + + assert.Equal(t, 3, rb.Length()) + + actual, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p1, actual) + assert.Equal(t, 2, rb.Length()) + + err = rb.Push(p4) + assert.NoError(t, err) + err = rb.Push(p5) + assert.NoError(t, err) + + assert.Equal(t, 4, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p2, actual) + + assert.Equal(t, 3, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p3, actual) + + assert.Equal(t, 2, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p4, actual) + + assert.Equal(t, 1, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p5, actual) + assert.NotEqual(t, p1, p5) + assert.Equal(t, 0, rb.Length()) + }) + t.Run("partial overflow, non-blocking", func(t *testing.T) { + rb := NewCircular(4) + p1 := testPacket() + p1.Metadata.Id = 1 + + p2 := testPacket() + p2.Metadata.Id = 2 + + p3 := testPacket() + p3.Metadata.Id = 3 + + p4 := testPacket() + p4.Metadata.Id = 4 + + p5 := testPacket() + p5.Metadata.Id = 5 + + p6 := testPacket() + 
p6.Metadata.Id = 6 + + err := rb.Push(p1) + assert.NoError(t, err) + err = rb.Push(p2) + assert.NoError(t, err) + err = rb.Push(p3) + assert.NoError(t, err) + err = rb.Push(p4) + assert.NoError(t, err) + + assert.Equal(t, 4, rb.Length()) + + err = rb.Push(p5) + assert.NoError(t, err) + + assert.Equal(t, 5, rb.Length()) + + err = rb.Push(p6) + assert.NoError(t, err) + + assert.Equal(t, 6, rb.Length()) + + actual, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p1, actual) + + assert.Equal(t, 5, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p2, actual) + + assert.Equal(t, 4, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p3, actual) + + assert.Equal(t, 3, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p4, actual) + assert.NotEqual(t, p1, p4) + assert.Equal(t, 2, rb.Length()) + }) +} diff --git a/internal/queue/lockfree.go b/internal/queue/lockfree.go new file mode 100644 index 0000000..271fdef --- /dev/null +++ b/internal/queue/lockfree.go @@ -0,0 +1,259 @@ +/* + Copyright 2022 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package queue + +import ( + "github.com/loopholelabs/frisbee/pkg/packet" + "runtime" + "sync/atomic" + "unsafe" +) + +// node is a struct that keeps track of its own position as well as a piece of data +// stored as an unsafe.Pointer. 
Normally we would store the pointer to a packet.Packet +// directly, however benchmarking shows performance improvements with unsafe.Pointer instead +type node struct { + _padding0 [8]uint64 //nolint:structcheck,unused + position uint64 + _padding1 [8]uint64 //nolint:structcheck,unused + data unsafe.Pointer +} + +// nodes is a struct type containing a slice of node pointers +type nodes []*node + +// LockFree is the struct used to store a blocking or non-blocking FIFO queue of type *packet.Packet +// +// In it's non-blocking form it acts as a ringbuffer, overwriting old data when new data arrives. In its blocking +// form it waits for a space in the queue to open up before it adds the item to the LockFree. +type LockFree struct { + _padding0 [8]uint64 //nolint:structcheck,unused + head uint64 + _padding1 [8]uint64 //nolint:structcheck,unused + tail uint64 + _padding2 [8]uint64 //nolint:structcheck,unused + mask uint64 + _padding3 [8]uint64 //nolint:structcheck,unused + closed uint64 + _padding4 [8]uint64 //nolint:structcheck,unused + nodes nodes + _padding5 [8]uint64 //nolint:structcheck,unused + overflow func() (uint64, error) +} + +// NewLockFree creates a new LockFree with blocking or non-blocking behavior +func NewLockFree(size uint64, blocking bool) *LockFree { + q := new(LockFree) + if size < 1 { + size = 1 + } + if blocking { + q.overflow = q.blocker + } else { + q.overflow = q.unblocker + } + q.init(size) + return q +} + +// init actually initializes a queue and can be used in the future to reuse LockFree structs +// with their own pool +func (q *LockFree) init(size uint64) { + size = round(size) + q.nodes = make(nodes, size) + for i := uint64(0); i < size; i++ { + q.nodes[i] = &node{position: i} + } + q.mask = size - 1 +} + +// blocker is a LockFree.overflow function that blocks a Push operation from +// proceeding if the LockFree is ever full of data. 
+// +// If two Push operations happen simultaneously, blocker will block both of them until +// a Pop takes place, and unblock both of them at the same time. This can cause problems, +// however in our use case it won't because there shouldn't ever be more than one producer +// operating on the LockFree at any given time. There may be multiple consumers in the future, +// but that won't cause any problems. +// +// If we decide to use this as an MPMC LockFree instead of a SPMC LockFree (which is how we currently use it) +// then we can solve this bug by replacing the existing `default` switch case in the Push function with the +// following snippet: +// ``` +// default: +// head, err = q.overflow() +// if err != nil { +// return err +// } +// ``` +func (q *LockFree) blocker() (head uint64, err error) { +LOOP: + head = atomic.LoadUint64(&q.head) + if uint64(len(q.nodes)) == head-atomic.LoadUint64(&q.tail) { + if atomic.LoadUint64(&q.closed) == 1 { + err = Closed + return + } + runtime.Gosched() + goto LOOP + } + return +} + +// unblocker is a LockFree.overflow function that unblocks a Push operation from +// proceeding if the LockFree is full of data. It does this by adding its own Pop() +// operation before proceeding with the Push attempt. +// +// If two Push operations happen simultaneously, unblocker will unblock them both +// by running two Pop() operations. This function will also be called whenever there +// is a Push conflict (when two Push operations attempt to modify the queue concurrently). +// +// In highly concurrent situations we may lose more data than we should, however since we will +// be using this as a SPMC LockFree, this conflict will never arise. 
+func (q *LockFree) unblocker() (head uint64, err error) { + head = atomic.LoadUint64(&q.head) + if uint64(len(q.nodes)) == head-atomic.LoadUint64(&q.tail) { + var p *packet.Packet + p, err = q.Pop() + packet.Put(p) + if err != nil { + return + } + } + return +} + +// Push appends an item of type *packet.Packet to the LockFree, and will block +// until the item is pushed successfully (with the blocking function depending +// on whether this is a blocking LockFree). +// +// This method is not meant to be used concurrently, and the LockFree is meant to operate +// as an SPMC LockFree with one producer operating at a time. If we want to use this as an MPMC LockFree +// we can modify this Push function by replacing the existing `default` switch case with the +// following snippet: +// ``` +// default: +// head, err = q.overflow() +// if err != nil { +// return err +// } +// ``` +func (q *LockFree) Push(item *packet.Packet) error { + var newNode *node + head, err := q.overflow() + if err != nil { + return err + } +RETRY: + for { + if atomic.LoadUint64(&q.closed) == 1 { + return Closed + } + + newNode = q.nodes[head&q.mask] + switch dif := atomic.LoadUint64(&newNode.position) - head; { + case dif == 0: + if atomic.CompareAndSwapUint64(&q.head, head, head+1) { + break RETRY + } + default: + head = atomic.LoadUint64(&q.head) + } + runtime.Gosched() + } + newNode.data = unsafe.Pointer(item) + atomic.StoreUint64(&newNode.position, head+1) + return nil +} + +// Pop removes an item from the start of the LockFree and returns it to the caller. +// This method blocks until an item is available, but unblocks when the LockFree is closed. +// This allows for long-term listeners to wait on the LockFree until either an item is available +// or the LockFree is closed. +// +// This method is safe to be used concurrently and is even optimized for the SPMC use case. 
+func (q *LockFree) Pop() (*packet.Packet, error) { + var oldNode *node + var oldPosition = atomic.LoadUint64(&q.tail) +RETRY: + if atomic.LoadUint64(&q.closed) == 1 { + return nil, Closed + } + + oldNode = q.nodes[oldPosition&q.mask] + switch dif := atomic.LoadUint64(&oldNode.position) - (oldPosition + 1); { + case dif == 0: + if atomic.CompareAndSwapUint64(&q.tail, oldPosition, oldPosition+1) { + goto DONE + } + default: + oldPosition = atomic.LoadUint64(&q.tail) + } + runtime.Gosched() + goto RETRY +DONE: + data := oldNode.data + oldNode.data = nil + atomic.StoreUint64(&oldNode.position, oldPosition+q.mask+1) + return (*packet.Packet)(data), nil +} + +// Close marks the LockFree as closed, returns any waiting Pop() calls, +// and blocks all future Push calls from occurring. +func (q *LockFree) Close() { + atomic.CompareAndSwapUint64(&q.closed, 0, 1) +} + +// IsClosed returns whether the LockFree has been closed +func (q *LockFree) IsClosed() bool { + return atomic.LoadUint64(&q.closed) == 1 +} + +// Length is the current number of items in the LockFree +func (q *LockFree) Length() int { + return int(atomic.LoadUint64(&q.head) - atomic.LoadUint64(&q.tail)) +} + +// Drain drains all the current packets in the queue and returns them to the caller. +// +// It is an unsafe function that should only be used once, only after the queue has been closed, +// and only while there are no producers writing to it. If used incorrectly it has the potential +// to infinitely block the caller. If used correctly, it allows a single caller to drain any remaining +// packets in the queue after the queue has been closed. 
+func (q *LockFree) Drain() []*packet.Packet { + length := q.Length() + packets := make([]*packet.Packet, 0, length) + for i := 0; i < length; i++ { + var oldNode *node + var oldPosition = atomic.LoadUint64(&q.tail) + RETRY: + oldNode = q.nodes[oldPosition&q.mask] + switch dif := atomic.LoadUint64(&oldNode.position) - (oldPosition + 1); { + case dif == 0: + if atomic.CompareAndSwapUint64(&q.tail, oldPosition, oldPosition+1) { + goto DONE + } + default: + oldPosition = atomic.LoadUint64(&q.tail) + } + runtime.Gosched() + goto RETRY + DONE: + data := oldNode.data + oldNode.data = nil + atomic.StoreUint64(&oldNode.position, oldPosition+q.mask+1) + packets = append(packets, (*packet.Packet)(data)) + } + return packets +} diff --git a/internal/queue/lockfree_test.go b/internal/queue/lockfree_test.go new file mode 100644 index 0000000..506614b --- /dev/null +++ b/internal/queue/lockfree_test.go @@ -0,0 +1,243 @@ +/* + Copyright 2022 Loophole Labs + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package queue + +import ( + "github.com/loopholelabs/frisbee/pkg/packet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestLockFree(t *testing.T) { + t.Parallel() + + testPacket := func() *packet.Packet { + return packet.Get() + } + testPacket2 := func() *packet.Packet { + p := packet.Get() + p.Content.Write([]byte{1}) + return p + } + + t.Run("success", func(t *testing.T) { + rb := NewLockFree(1, false) + p := testPacket() + err := rb.Push(p) + assert.NoError(t, err) + actual, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p, actual) + }) + t.Run("out of capacity", func(t *testing.T) { + rb := NewLockFree(0, false) + err := rb.Push(testPacket()) + assert.NoError(t, err) + }) + t.Run("out of capacity with non zero capacity, blocking", func(t *testing.T) { + rb := NewLockFree(1, true) + p1 := testPacket() + err := rb.Push(p1) + assert.NoError(t, err) + doneCh := make(chan struct{}, 1) + p2 := testPacket2() + go func() { + err = rb.Push(p2) + assert.NoError(t, err) + doneCh <- struct{}{} + }() + select { + case <-doneCh: + t.Fatal("LockFree did not block on full write") + case <-time.After(time.Millisecond * 10): + actual, err := rb.Pop() + require.NoError(t, err) + assert.Equal(t, p1, actual) + select { + case <-doneCh: + actual, err := rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + case <-time.After(time.Millisecond * 10): + t.Fatal("LockFree did not unblock on read from full write") + } + } + }) + t.Run("out of capacity with non zero capacity, non-blocking", func(t *testing.T) { + rb := NewLockFree(1, false) + p1 := testPacket() + err := rb.Push(p1) + assert.NoError(t, err) + assert.Equal(t, 1, rb.Length()) + p2 := testPacket2() + err = rb.Push(p2) + assert.NoError(t, err) + assert.Equal(t, 1, rb.Length()) + actual, err := rb.Pop() + require.NoError(t, err) + assert.Equal(t, p2, actual) + assert.Equal(t, 0, rb.Length()) + }) + t.Run("buffer closed", func(t 
*testing.T) { + rb := NewLockFree(1, false) + assert.False(t, rb.IsClosed()) + rb.Close() + assert.True(t, rb.IsClosed()) + err := rb.Push(testPacket()) + assert.ErrorIs(t, Closed, err) + _, err = rb.Pop() + assert.ErrorIs(t, Closed, err) + }) + t.Run("pop empty", func(t *testing.T) { + done := make(chan struct{}, 1) + rb := NewLockFree(1, false) + go func() { + _, _ = rb.Pop() + done <- struct{}{} + }() + assert.Equal(t, 0, len(done)) + _ = rb.Push(testPacket()) + <-done + assert.Equal(t, 0, rb.Length()) + }) + t.Run("partial overflow, blocking", func(t *testing.T) { + rb := NewLockFree(4, true) + p1 := testPacket() + p1.Metadata.Id = 1 + + p2 := testPacket() + p2.Metadata.Id = 2 + + p3 := testPacket() + p3.Metadata.Id = 3 + + p4 := testPacket() + p4.Metadata.Id = 4 + + p5 := testPacket() + p5.Metadata.Id = 5 + + err := rb.Push(p1) + assert.NoError(t, err) + err = rb.Push(p2) + assert.NoError(t, err) + err = rb.Push(p3) + assert.NoError(t, err) + + assert.Equal(t, 3, rb.Length()) + + actual, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p1, actual) + assert.Equal(t, 2, rb.Length()) + + err = rb.Push(p4) + assert.NoError(t, err) + err = rb.Push(p5) + assert.NoError(t, err) + + assert.Equal(t, 4, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p2, actual) + + assert.Equal(t, 3, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p3, actual) + + assert.Equal(t, 2, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p4, actual) + + assert.Equal(t, 1, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p5, actual) + assert.NotEqual(t, p1, p5) + assert.Equal(t, 0, rb.Length()) + }) + t.Run("partial overflow, non-blocking", func(t *testing.T) { + rb := NewLockFree(4, false) + p1 := testPacket() + p1.Metadata.Id = 1 + + p2 := testPacket() + p2.Metadata.Id = 2 + + p3 := testPacket() + p3.Metadata.Id = 3 + + p4 := testPacket() + 
p4.Metadata.Id = 4 + + p5 := testPacket() + p5.Metadata.Id = 5 + + p6 := testPacket() + p6.Metadata.Id = 6 + + err := rb.Push(p1) + assert.NoError(t, err) + err = rb.Push(p2) + assert.NoError(t, err) + err = rb.Push(p3) + assert.NoError(t, err) + err = rb.Push(p4) + assert.NoError(t, err) + + assert.Equal(t, 4, rb.Length()) + + err = rb.Push(p5) + assert.NoError(t, err) + + assert.Equal(t, 4, rb.Length()) + + err = rb.Push(p6) + assert.NoError(t, err) + + assert.Equal(t, 4, rb.Length()) + + actual, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p3, actual) + + assert.Equal(t, 3, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p4, actual) + + assert.Equal(t, 2, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p5, actual) + + assert.Equal(t, 1, rb.Length()) + + actual, err = rb.Pop() + assert.NoError(t, err) + assert.Equal(t, p6, actual) + assert.NotEqual(t, p1, p6) + assert.Equal(t, 0, rb.Length()) + }) +} diff --git a/internal/queue/queue.go b/internal/queue/queue.go index ed5fe3b..51e2d64 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -1,12 +1,9 @@ /* Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,11 +14,7 @@ package queue import ( - "github.com/loopholelabs/frisbee/pkg/packet" "github.com/pkg/errors" - "runtime" - "sync/atomic" - "unsafe" ) var ( @@ -40,241 +33,3 @@ func round(value uint64) uint64 { value++ return value } - -// node is a struct that keeps track of its own position as well as a piece of data -// stored as an unsafe.Pointer. 
Normally we would store the pointer to a packet.Packet -// directly, however benchmarking shows performance improvements with unsafe.Pointer instead -type node struct { - _padding0 [8]uint64 //nolint:structcheck,unused - position uint64 - _padding1 [8]uint64 //nolint:structcheck,unused - data unsafe.Pointer -} - -// nodes is a struct type containing a slice of node pointers -type nodes []*node - -// Queue is the struct used to store a blocking or non-blocking FIFO queue of type *packet.Packet -// -// In it's non-blocking form it acts as a ringbuffer, overwriting old data when new data arrives. In its blocking -// form it waits for a space in the queue to open up before it adds the item to the Queue. -type Queue struct { - _padding0 [8]uint64 //nolint:structcheck,unused - head uint64 - _padding1 [8]uint64 //nolint:structcheck,unused - tail uint64 - _padding2 [8]uint64 //nolint:structcheck,unused - mask uint64 - _padding3 [8]uint64 //nolint:structcheck,unused - closed uint64 - _padding4 [8]uint64 //nolint:structcheck,unused - nodes nodes - _padding5 [8]uint64 //nolint:structcheck,unused - overflow func() (uint64, error) -} - -// New creates a new Queue with blocking or non-blocking behavior -func New(size uint64, blocking bool) *Queue { - q := new(Queue) - if size < 1 { - size = 1 - } - if blocking { - q.overflow = q.blocker - } else { - q.overflow = q.unblocker - } - q.init(size) - return q -} - -// init actually initializes a queue and can be used in the future to reuse Queue structs -// with their own pool -func (q *Queue) init(size uint64) { - size = round(size) - q.nodes = make(nodes, size) - for i := uint64(0); i < size; i++ { - q.nodes[i] = &node{position: i} - } - q.mask = size - 1 -} - -// blocker is a Queue.overflow function that blocks a Push operation from -// proceeding if the Queue is ever full of data. 
-// -// If two Push operations happen simultaneously, blocker will block both of them until -// a Pop takes place, and unblock both of them at the same time. This can cause problems, -// however in our use case it won't because there shouldn't ever be more than one producer -// operating on the Queue at any given time. There may be multiple consumers in the future, -// but that won't cause any problems. -// -// If we decide to use this as an MPMC Queue instead of a SPMC Queue (which is how we currently use it) -// then we can solve this bug by replacing the existing `default` switch case in the Push function with the -// following snippet: -// ``` -// default: -// head, err = q.overflow() -// if err != nil { -// return err -// } -// ``` -func (q *Queue) blocker() (head uint64, err error) { -LOOP: - head = atomic.LoadUint64(&q.head) - if uint64(len(q.nodes)) == head-atomic.LoadUint64(&q.tail) { - if atomic.LoadUint64(&q.closed) == 1 { - err = Closed - return - } - runtime.Gosched() - goto LOOP - } - return -} - -// unblocker is a Queue.overflow function that unblocks a Push operation from -// proceeding if the Queue is full of data. It does this by adding its own Pop() -// operation before proceeding with the Push attempt. -// -// If two Push operations happen simultaneously, unblocker will unblock them both -// by running two Pop() operations. This function will also be called whenever there -// is a Push conflict (when two Push operations attempt to modify the queue concurrently). -// -// In highly concurrent situations we may lose more data than we should, however since we will -// be using this as a SPMC Queue, this conflict will never arise. 
-func (q *Queue) unblocker() (head uint64, err error) { - head = atomic.LoadUint64(&q.head) - if uint64(len(q.nodes)) == head-atomic.LoadUint64(&q.tail) { - var p *packet.Packet - p, err = q.Pop() - packet.Put(p) - if err != nil { - return - } - } - return -} - -// Push appends an item of type *packet.Packet to the Queue, and will block -// until the item is pushed successfully (with the blocking function depending -// on whether this is a blocking Queue). -// -// This method is not meant to be used concurrently, and the Queue is meant to operate -// as an SPMC Queue with one producer operating at a time. If we want to use this as an MPMC Queue -// we can modify this Push function by replacing the existing `default` switch case with the -// following snippet: -// ``` -// default: -// head, err = q.overflow() -// if err != nil { -// return err -// } -// ``` -func (q *Queue) Push(item *packet.Packet) error { - var newNode *node - head, err := q.overflow() - if err != nil { - return err - } -RETRY: - for { - if atomic.LoadUint64(&q.closed) == 1 { - return Closed - } - - newNode = q.nodes[head&q.mask] - switch dif := atomic.LoadUint64(&newNode.position) - head; { - case dif == 0: - if atomic.CompareAndSwapUint64(&q.head, head, head+1) { - break RETRY - } - default: - head = atomic.LoadUint64(&q.head) - } - runtime.Gosched() - } - newNode.data = unsafe.Pointer(item) - atomic.StoreUint64(&newNode.position, head+1) - return nil -} - -// Pop removes an item from the start of the Queue and returns it to the caller. -// This method blocks until an item is available, but unblocks when the Queue is closed. -// This allows for long-term listeners to wait on the Queue until either an item is available -// or the Queue is closed. -// -// This method is safe to be used concurrently and is even optimized for the SPMC use case. 
-func (q *Queue) Pop() (*packet.Packet, error) { - var oldNode *node - var oldPosition = atomic.LoadUint64(&q.tail) -RETRY: - if atomic.LoadUint64(&q.closed) == 1 { - return nil, Closed - } - - oldNode = q.nodes[oldPosition&q.mask] - switch dif := atomic.LoadUint64(&oldNode.position) - (oldPosition + 1); { - case dif == 0: - if atomic.CompareAndSwapUint64(&q.tail, oldPosition, oldPosition+1) { - goto DONE - } - default: - oldPosition = atomic.LoadUint64(&q.tail) - } - runtime.Gosched() - goto RETRY -DONE: - data := oldNode.data - oldNode.data = nil - atomic.StoreUint64(&oldNode.position, oldPosition+q.mask+1) - return (*packet.Packet)(data), nil -} - -// Close marks the Queue as closed, returns any waiting Pop() calls, -// and blocks all future Push calls from occurring. -func (q *Queue) Close() { - atomic.CompareAndSwapUint64(&q.closed, 0, 1) -} - -// IsClosed returns whether the Queue has been closed -func (q *Queue) IsClosed() bool { - return atomic.LoadUint64(&q.closed) == 1 -} - -// Length is the current number of items in the Queue -func (q *Queue) Length() int { - return int(atomic.LoadUint64(&q.head) - atomic.LoadUint64(&q.tail)) -} - -// Drain drains all the current packets in the queue and returns them to the caller. -// -// It is an unsafe function that should only be used once, only after the queue has been closed, -// and only while there are no producers writing to it. If used incorrectly it has the potential -// to infinitely block the caller. If used correctly, it allows a single caller to drain any remaining -// packets in the queue after the queue has been closed. 
-func (q *Queue) Drain() []*packet.Packet { - length := q.Length() - packets := make([]*packet.Packet, 0, length) - for i := 0; i < length; i++ { - var oldNode *node - var oldPosition = atomic.LoadUint64(&q.tail) - RETRY: - oldNode = q.nodes[oldPosition&q.mask] - switch dif := atomic.LoadUint64(&oldNode.position) - (oldPosition + 1); { - case dif == 0: - if atomic.CompareAndSwapUint64(&q.tail, oldPosition, oldPosition+1) { - goto DONE - } - default: - oldPosition = atomic.LoadUint64(&q.tail) - } - runtime.Gosched() - goto RETRY - DONE: - data := oldNode.data - oldNode.data = nil - atomic.StoreUint64(&oldNode.position, oldPosition+q.mask+1) - packets = append(packets, (*packet.Packet)(data)) - } - return packets -} diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go index 12d0d11..0bba933 100644 --- a/internal/queue/queue_test.go +++ b/internal/queue/queue_test.go @@ -1,12 +1,9 @@ /* Copyright 2022 Loophole Labs - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -17,258 +14,31 @@ package queue import ( - "github.com/loopholelabs/frisbee/pkg/packet" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "testing" - "time" ) -func TestHelpers(t *testing.T) { +func TestRound(t *testing.T) { t.Parallel() - - t.Run("test round", func(t *testing.T) { - tcs := []struct { - in uint64 - expected uint64 - }{ - {in: 0, expected: 0x0}, - {in: 1, expected: 0x1}, - {in: 2, expected: 0x2}, - {in: 3, expected: 0x4}, - {in: 4, expected: 0x4}, - {in: 5, expected: 0x8}, - {in: 7, expected: 0x8}, - {in: 8, expected: 0x8}, - {in: 9, expected: 0x10}, - {in: 16, expected: 0x10}, - {in: 32, expected: 0x20}, - {in: 0xFFFFFFF0, expected: 0x100000000}, - {in: 0xFFFFFFFF, expected: 0x100000000}, - } - for _, tc := range tcs { - assert.Equalf(t, tc.expected, round(tc.in), "in: %d", tc.in) - } - }) -} - -func TestQueue(t *testing.T) { - t.Parallel() - - testPacket := func() *packet.Packet { - return packet.Get() + tcs := []struct { + in uint64 + expected uint64 + }{ + {in: 0, expected: 0x0}, + {in: 1, expected: 0x1}, + {in: 2, expected: 0x2}, + {in: 3, expected: 0x4}, + {in: 4, expected: 0x4}, + {in: 5, expected: 0x8}, + {in: 7, expected: 0x8}, + {in: 8, expected: 0x8}, + {in: 9, expected: 0x10}, + {in: 16, expected: 0x10}, + {in: 32, expected: 0x20}, + {in: 0xFFFFFFF0, expected: 0x100000000}, + {in: 0xFFFFFFFF, expected: 0x100000000}, } - testPacket2 := func() *packet.Packet { - p := packet.Get() - p.Content.Write([]byte{1}) - return p + for _, tc := range tcs { + assert.Equalf(t, tc.expected, round(tc.in), "in: %d", tc.in) } - - t.Run("success", func(t *testing.T) { - rb := New(1, false) - p := testPacket() - err := rb.Push(p) - assert.NoError(t, err) - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p, actual) - }) - t.Run("out of capacity", func(t *testing.T) { - rb := New(0, false) - err := rb.Push(testPacket()) - assert.NoError(t, err) - }) - t.Run("out of capacity with non zero capacity, blocking", 
func(t *testing.T) { - rb := New(1, true) - p1 := testPacket() - err := rb.Push(p1) - assert.NoError(t, err) - doneCh := make(chan struct{}, 1) - p2 := testPacket2() - go func() { - err = rb.Push(p2) - assert.NoError(t, err) - doneCh <- struct{}{} - }() - select { - case <-doneCh: - t.Fatal("Queue did not block on full write") - case <-time.After(time.Millisecond * 10): - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p1, actual) - select { - case <-doneCh: - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - case <-time.After(time.Millisecond * 10): - t.Fatal("Queue did not unblock on read from full write") - } - } - }) - t.Run("out of capacity with non zero capacity, non-blocking", func(t *testing.T) { - rb := New(1, false) - p1 := testPacket() - err := rb.Push(p1) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - p2 := testPacket2() - err = rb.Push(p2) - assert.NoError(t, err) - assert.Equal(t, 1, rb.Length()) - actual, err := rb.Pop() - require.NoError(t, err) - assert.Equal(t, p2, actual) - assert.Equal(t, 0, rb.Length()) - }) - t.Run("buffer closed", func(t *testing.T) { - rb := New(1, false) - assert.False(t, rb.IsClosed()) - rb.Close() - assert.True(t, rb.IsClosed()) - err := rb.Push(testPacket()) - assert.ErrorIs(t, Closed, err) - _, err = rb.Pop() - assert.ErrorIs(t, Closed, err) - }) - t.Run("pop empty", func(t *testing.T) { - done := make(chan struct{}, 1) - rb := New(1, false) - go func() { - _, _ = rb.Pop() - done <- struct{}{} - }() - assert.Equal(t, 0, len(done)) - _ = rb.Push(testPacket()) - <-done - assert.Equal(t, 0, rb.Length()) - }) - t.Run("partial overflow, blocking", func(t *testing.T) { - rb := New(4, true) - p1 := testPacket() - p1.Metadata.Id = 1 - - p2 := testPacket() - p2.Metadata.Id = 2 - - p3 := testPacket() - p3.Metadata.Id = 3 - - p4 := testPacket() - p4.Metadata.Id = 4 - - p5 := testPacket() - p5.Metadata.Id = 5 - - err := rb.Push(p1) - assert.NoError(t, err) - err = 
rb.Push(p2) - assert.NoError(t, err) - err = rb.Push(p3) - assert.NoError(t, err) - - assert.Equal(t, 3, rb.Length()) - - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p1, actual) - assert.Equal(t, 2, rb.Length()) - - err = rb.Push(p4) - assert.NoError(t, err) - err = rb.Push(p5) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p2, actual) - - assert.Equal(t, 3, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p3, actual) - - assert.Equal(t, 2, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p4, actual) - - assert.Equal(t, 1, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p5, actual) - assert.NotEqual(t, p1, p5) - assert.Equal(t, 0, rb.Length()) - }) - t.Run("partial overflow, non-blocking", func(t *testing.T) { - rb := New(4, false) - p1 := testPacket() - p1.Metadata.Id = 1 - - p2 := testPacket() - p2.Metadata.Id = 2 - - p3 := testPacket() - p3.Metadata.Id = 3 - - p4 := testPacket() - p4.Metadata.Id = 4 - - p5 := testPacket() - p5.Metadata.Id = 5 - - p6 := testPacket() - p6.Metadata.Id = 6 - - err := rb.Push(p1) - assert.NoError(t, err) - err = rb.Push(p2) - assert.NoError(t, err) - err = rb.Push(p3) - assert.NoError(t, err) - err = rb.Push(p4) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - err = rb.Push(p5) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - err = rb.Push(p6) - assert.NoError(t, err) - - assert.Equal(t, 4, rb.Length()) - - actual, err := rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p3, actual) - - assert.Equal(t, 3, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p4, actual) - - assert.Equal(t, 2, rb.Length()) - - actual, err = rb.Pop() - assert.NoError(t, err) - assert.Equal(t, p5, actual) - - assert.Equal(t, 1, rb.Length()) - - actual, err = rb.Pop() - 
assert.NoError(t, err) - assert.Equal(t, p6, actual) - assert.NotEqual(t, p1, p6) - assert.Equal(t, 0, rb.Length()) - }) } diff --git a/protoc-gen-frisbee/dockerfile b/protoc-gen-frisbee/dockerfile index bc12675..fe1263e 100644 --- a/protoc-gen-frisbee/dockerfile +++ b/protoc-gen-frisbee/dockerfile @@ -2,7 +2,7 @@ FROM golang as builder ENV GOOS=linux GOARCH=amd64 CGO_ENABLED=0 -RUN go install github.com/loopholelabs/frisbee/protoc-gen-frisbee@v0.3.0 +RUN go install github.com/loopholelabs/frisbee/protoc-gen-frisbee@v0.3.1 # Note, the Docker images must be built for amd64. If the host machine architecture is not amd64 # you need to cross-compile the binary and move it into /go/bin. @@ -12,7 +12,7 @@ FROM scratch # Runtime dependencies LABEL "build.buf.plugins.runtime_library_versions.0.name"="github.com/loopholelabs/frisbee" -LABEL "build.buf.plugins.runtime_library_versions.0.version"="v0.3.0" +LABEL "build.buf.plugins.runtime_library_versions.0.version"="v0.3.1" COPY --from=builder /go/bin / diff --git a/server.go b/server.go index ad02e56..afb1740 100644 --- a/server.go +++ b/server.go @@ -39,13 +39,14 @@ var ( // Server accepts connections from frisbee Clients and can send and receive frisbee Packets type Server struct { - listener net.Listener - addr string - handlerTable HandlerTable - shutdown *atomic.Bool - shutdownCh chan struct{} - options *Options - wg sync.WaitGroup + listener net.Listener + addr string + handlerTable HandlerTable + shutdown *atomic.Bool + options *Options + wg sync.WaitGroup + connections map[*Async]struct{} + connectionsMu sync.Mutex // BaseContext is used to define the base context for this Server and all incoming connections BaseContext func() context.Context @@ -91,7 +92,7 @@ func NewServer(addr string, handlerTable HandlerTable, opts ...Option) (*Server, handlerTable: handlerTable, options: options, shutdown: atomic.NewBool(false), - shutdownCh: make(chan struct{}), + connections: make(map[*Async]struct{}), }, nil } @@ -130,68 
+131,31 @@ func (s *Server) Start() error { } func (s *Server) handleListener() { -LOOP: - newConn, err := s.listener.Accept() - if err != nil { - if s.shutdown.Load() { - s.wg.Done() - return - } - s.Logger().Fatal().Err(err).Msg("error while accepting connection") - s.wg.Done() - return - } - s.wg.Add(1) - go s.handleConn(newConn) - goto LOOP -} - -func (s *Server) connCloser(conn *Async) { - select { - case <-conn.CloseChannel(): - s.wg.Done() - case <-s.shutdownCh: - _ = conn.Close() - s.wg.Done() - } -} - -func (s *Server) handleConn(newConn net.Conn) { + var newConn net.Conn var err error - switch v := newConn.(type) { - case *net.TCPConn: - err = v.SetKeepAlive(true) - if err != nil { - s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive") - _ = v.Close() - s.wg.Done() - return - } - err = v.SetKeepAlivePeriod(s.options.KeepAlive) + for { + newConn, err = s.listener.Accept() if err != nil { - s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive Period") - _ = v.Close() + if s.shutdown.Load() { + s.wg.Done() + return + } + s.Logger().Fatal().Err(err).Msg("error while accepting connection") s.wg.Done() return } + s.wg.Add(1) + go s.handleConn(newConn) } +} - frisbeeConn := NewAsync(newConn, s.Logger(), true) - - s.wg.Add(1) - go s.connCloser(frisbeeConn) - - connCtx := s.BaseContext() - +func (s *Server) handlePacket(frisbeeConn *Async, connCtx context.Context) (err error) { var p *packet.Packet var outgoing *packet.Packet var action Action var handlerFunc Handler p, err = frisbeeConn.ReadPacket() if err != nil { - _ = frisbeeConn.Close() - s.OnClosed(frisbeeConn, err) - s.wg.Done() return } if s.ConnContext != nil { @@ -201,9 +165,6 @@ func (s *Server) handleConn(newConn net.Conn) { LOOP: p, err = frisbeeConn.ReadPacket() if err != nil { - _ = frisbeeConn.Close() - s.OnClosed(frisbeeConn, err) - s.wg.Done() return } HANDLE: @@ -222,9 +183,6 @@ HANDLE: } packet.Put(p) if err != nil { - _ = frisbeeConn.Close() - 
s.OnClosed(frisbeeConn, err)
-			s.wg.Done()
 			return
 		}
 	} else {
@@ -237,9 +195,6 @@ HANDLE:
 				connCtx = s.UpdateContext(connCtx, frisbeeConn)
 			}
 		case CLOSE:
-			_ = frisbeeConn.Close()
-			s.OnClosed(frisbeeConn, nil)
-			s.wg.Done()
 			return
 		}
 	} else {
@@ -248,6 +203,52 @@ HANDLE:
 	goto LOOP
 }
 
+func (s *Server) handleConn(newConn net.Conn) {
+	var err error
+	switch v := newConn.(type) {
+	case *net.TCPConn:
+		err = v.SetKeepAlive(true)
+		if err != nil {
+			s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive")
+			_ = v.Close()
+			s.wg.Done()
+			return
+		}
+		err = v.SetKeepAlivePeriod(s.options.KeepAlive)
+		if err != nil {
+			s.Logger().Error().Err(err).Msg("Error while setting TCP Keepalive Period")
+			_ = v.Close()
+			s.wg.Done()
+			return
+		}
+	}
+
+	frisbeeConn := NewAsync(newConn, s.Logger(), true)
+	connCtx := s.BaseContext()
+
+	s.connectionsMu.Lock()
+	if s.shutdown.Load() {
+		// Server is shutting down: release the mutex before returning (a bare
+		// return here would hold the lock forever and deadlock Shutdown) and
+		// close the connection we just wrapped so it does not leak.
+		s.connectionsMu.Unlock()
+		_ = frisbeeConn.Close()
+		s.wg.Done()
+		return
+	}
+	s.connections[frisbeeConn] = struct{}{}
+	s.connectionsMu.Unlock()
+
+	err = s.handlePacket(frisbeeConn, connCtx)
+	_ = frisbeeConn.Close()
+	s.OnClosed(frisbeeConn, err)
+	s.connectionsMu.Lock()
+	if !s.shutdown.Load() {
+		delete(s.connections, frisbeeConn)
+	}
+	s.connectionsMu.Unlock()
+	s.wg.Done()
+}
+
 // Logger returns the server's logger (useful for ServerRouter functions)
 func (s *Server) Logger() *zerolog.Logger {
 	return s.options.Logger
@@ -256,7 +257,12 @@ func (s *Server) Logger() *zerolog.Logger {
 // Shutdown shuts down the frisbee server and kills all the goroutines and active connections
 func (s *Server) Shutdown() error {
 	s.shutdown.Store(true)
-	close(s.shutdownCh)
+	s.connectionsMu.Lock()
+	for c := range s.connections {
+		_ = c.Close()
+		delete(s.connections, c)
+	}
+	s.connectionsMu.Unlock()
 	defer s.wg.Wait()
 	return s.listener.Close()
 }