From 1c5768784faa2b514b0b363e51d692df25091c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Mon, 13 Jun 2022 19:50:00 +0200 Subject: [PATCH 01/46] implement ticket queue --- go.mod | 1 + go.sum | 2 + pkg/foundation/ticketqueue/ticketqueue.go | 170 +++++++++++ .../ticketqueue/ticketqueue_test.go | 266 ++++++++++++++++++ 4 files changed, 439 insertions(+) create mode 100644 pkg/foundation/ticketqueue/ticketqueue.go create mode 100644 pkg/foundation/ticketqueue/ticketqueue_test.go diff --git a/go.mod b/go.mod index f00630e92..132eeb53e 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/conduitio/conduit-connector-sdk v0.2.0 github.com/dgraph-io/badger/v3 v3.2103.2 github.com/dop251/goja v0.0.0-20210225094849-f3cfc97811c0 + github.com/gammazero/deque v0.2.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.8 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index be1df8af7..2f695f1ce 100644 --- a/go.sum +++ b/go.sum @@ -205,6 +205,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= +github.com/gammazero/deque v0.2.0 h1:SkieyNB4bg2/uZZLxvya0Pq6diUlwx7m2TeT7GAIWaA= +github.com/gammazero/deque v0.2.0/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= diff --git a/pkg/foundation/ticketqueue/ticketqueue.go b/pkg/foundation/ticketqueue/ticketqueue.go new file mode 100644 index 000000000..d738ce607 --- /dev/null +++ b/pkg/foundation/ticketqueue/ticketqueue.go @@ -0,0 +1,170 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ticketqueue + +import ( + "context" + + "github.com/conduitio/conduit/pkg/foundation/cerrors" + "github.com/gammazero/deque" +) + +// TicketQueue dispenses tickets and keeps track of their order. Tickets can be +// "cashed-in" for channels that let the caller communicate with a worker that +// is supposed to handle the ticket. TicketQueue ensures that tickets are +// cashed-in in the exact same order as they were dispensed. +// +// Essentially TicketQueue simulates a "take a number" system where the number +// on the ticket is monotonically increasing with each dispensed ticket. When +// the monitor displays the number on the ticket, the person holding the ticket +// can approach the counter. +// +// TicketQueue contains an unbounded buffer for tickets. A goroutine is pushing +// tickets to workers calling Next. 
To stop this goroutine the TicketQueue needs +// to be closed and all tickets need to be drained through Next until it returns +// an error. +type TicketQueue[REQ, RES any] struct { + // in is the channel where incoming tickets are sent into (see Take) + in chan Ticket[REQ, RES] + // out is the channel where outgoing tickets are sent into (see Next) + out chan Ticket[REQ, RES] +} + +// NewTicketQueue returns an initialized TicketQueue. +func NewTicketQueue[REQ, RES any]() *TicketQueue[REQ, RES] { + tq := &TicketQueue[REQ, RES]{ + in: make(chan Ticket[REQ, RES]), + out: make(chan Ticket[REQ, RES]), + } + tq.run() + return tq +} + +// Ticket is dispensed by TicketQueue. Once TicketQueue.Wait is called with a +// Ticket it should be discarded. +type Ticket[REQ, RES any] struct { + ctrl chan struct{} + req chan REQ + res chan RES +} + +// run launches a goroutine that fetches tickets from the channel in and buffers +// them in an unbounded queue. It also pushes tickets from the queue into the +// channel out. +func (tq *TicketQueue[REQ, RES]) run() { + in := tq.in + + // Deque is used as a normal queue and holds references to all open tickets + var q deque.Deque[Ticket[REQ, RES]] + outOrNil := func() chan Ticket[REQ, RES] { + if q.Len() == 0 { + return nil + } + return tq.out + } + nextTicket := func() Ticket[REQ, RES] { + if q.Len() == 0 { + return Ticket[REQ, RES]{} + } + return q.Front() + } + + go func() { + defer close(tq.out) + for q.Len() > 0 || in != nil { + select { + case v, ok := <-in: + if !ok { + in = nil + continue + } + q.PushBack(v) + case outOrNil() <- nextTicket(): + q.PopFront() // remove ticket from queue + } + } + }() +} + +// Take creates a ticket. The ticket can be used to call Wait. If TicketQueue +// is already closed, the call panics. +func (tq *TicketQueue[REQ, RES]) Take() Ticket[REQ, RES] { + t := Ticket[REQ, RES]{ + ctrl: make(chan struct{}), + req: make(chan REQ), + res: make(chan RES), + } + tq.in <- t + return t +} + +// Wait will block until all tickets before this ticket were already processed. +// Essentially this method means the caller wants to enqueue and wait for their +// turn. The function returns two channels that can be used to communicate with +// the processor of the ticket. The caller determines what messages are sent +// through those channels (if any). After Wait returns the ticket should be +// discarded. +// +// If ctx gets cancelled before the ticket is redeemed, the function returns the +// context error. If Wait is called a second time with the same ticket, the call +// returns an error. +func (tq *TicketQueue[REQ, RES]) Wait(ctx context.Context, t Ticket[REQ, RES]) (chan<- REQ, <-chan RES, error) { + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + case _, ok := <-t.ctrl: + if !ok { + return nil, nil, cerrors.New("ticket already used") + } + } + return t.req, t.res, nil +} + +// Next can be used to fetch the channels to communicate with the next ticket +// holder in line. If there is no next ticket holder or if the next ticket +// holder did not call Wait, the call will block. +// +// If ctx gets cancelled before the next ticket holder is ready, the function +// returns the context error. If TicketQueue is closed and there are no more +// open tickets, the call returns an error. 
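+//
+// A typical worker loop might look like the following sketch, where process
+// is a placeholder for the worker's own logic:
+//
+//	for {
+//		req, res, err := tq.Next(ctx)
+//		if err != nil {
+//			break // TicketQueue is closed or ctx got cancelled
+//		}
+//		v := <-req        // receive the request from the ticket holder
+//		res <- process(v) // send back the response
+//	}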
+func (tq *TicketQueue[REQ, RES]) Next(ctx context.Context) (<-chan REQ, chan<- RES, error) { + var t Ticket[REQ, RES] + var ok bool + + select { + case <-ctx.Done(): + return nil, nil, ctx.Err() + case t, ok = <-tq.out: + if !ok { + return nil, nil, cerrors.New("TicketQueue is closed") + } + } + + select { + case <-ctx.Done(): + // BUG: the ticket is lost at this point + return nil, nil, ctx.Err() + case t.ctrl <- struct{}{}: // signal that Next is ready to proceed + close(t.ctrl) // ticket is used + } + + return t.req, t.res, nil +} + +// Close the ticket queue, no more new tickets can be dispensed after this. +// Calls to Wait and Next are still allowed until all open tickets are redeemed. +func (tq *TicketQueue[REQ, RES]) Close() { + close(tq.in) +} diff --git a/pkg/foundation/ticketqueue/ticketqueue_test.go b/pkg/foundation/ticketqueue/ticketqueue_test.go new file mode 100644 index 000000000..76f350f46 --- /dev/null +++ b/pkg/foundation/ticketqueue/ticketqueue_test.go @@ -0,0 +1,266 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ticketqueue + +import ( + "context" + "fmt" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/conduitio/conduit/pkg/foundation/cerrors" + "github.com/matryer/is" +) + +func TestTicketQueue_Next_ContextCanceled(t *testing.T) { + is := is.New(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + + tq := NewTicketQueue[int, float64]() + defer tq.Close() + + req, res, err := tq.Next(ctx) + is.Equal(req, nil) + is.Equal(res, nil) + is.Equal(err, context.DeadlineExceeded) +} + +func TestTicketQueue_Next_Closed(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + tq := NewTicketQueue[int, float64]() + tq.Close() // close ticket queue + + req, res, err := tq.Next(ctx) + is.Equal(req, nil) + is.Equal(res, nil) + is.True(err != nil) +} + +func TestTicketQueue_Take_Closed(t *testing.T) { + is := is.New(t) + + tq := NewTicketQueue[int, float64]() + tq.Close() // close ticket queue, taking a ticket after this is not permitted + + defer func() { + is.True(recover() != nil) // expected Take to panic + }() + + tq.Take() +} + +func TestTicketQueue_Wait_ContextCanceled(t *testing.T) { + is := is.New(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + + tq := NewTicketQueue[int, float64]() + defer tq.Close() + + ticket := tq.Take() + req, res, err := tq.Wait(ctx, ticket) + is.Equal(req, nil) + is.Equal(res, nil) + is.Equal(err, context.DeadlineExceeded) +} + +func TestTicketQueue_Wait_ReuseTicket(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + tq := NewTicketQueue[int, float64]() + defer tq.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + _, _, err := tq.Next(ctx) + is.NoErr(err) + _, _, err = tq.Next(ctx) + is.Equal(err, 
context.DeadlineExceeded) + }() + + ticket := tq.Take() + _, _, err := tq.Wait(ctx, ticket) + is.NoErr(err) + + _, _, err = tq.Wait(ctx, ticket) + is.True(err != nil) // expected error for ticket that was already cashed-in + wg.Wait() +} + +func TestTicketQueue_Next_NoTicketWaiting(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + tq := NewTicketQueue[int, float64]() + defer tq.Close() + + tq.Take() // take ticket, but don't cash it in + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + _, _, err := tq.Next(ctx) + is.Equal(err, context.DeadlineExceeded) +} + +func TestTicketQueue_Take_Buffer(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + tq := NewTicketQueue[int, float64]() + defer tq.Close() + + // TicketQueue supports an unbounded amount of tickets, keep taking tickets + // for one second and take as many tickets as possible + testDuration := time.Second + + var wg sync.WaitGroup + var numTickets int + start := time.Now() + for time.Since(start) < testDuration { + numTickets += 1 + ticket := tq.Take() + go func() { + defer wg.Done() + _, _, err := tq.Wait(ctx, ticket) + is.NoErr(err) + }() + } + wg.Add(numTickets) + t.Logf("took %d tickets in %s", numTickets, testDuration) + + for i := 0; i < numTickets; i++ { + _, _, err := tq.Next(ctx) + is.NoErr(err) + } + + wg.Wait() // wait for all ticket goroutines to finish + + // try fetching next in line, but there is none + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) + defer cancel() + _, _, err := tq.Next(ctx) + is.Equal(err, context.DeadlineExceeded) +} + +func TestTicketQueue_HandOff(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + tq := NewTicketQueue[int, float64]() + defer tq.Close() + + wantInt := 123 + wantFloat := 1.23 + + done := make(chan struct{}) + go func() { + defer close(done) + ticket := tq.Take() + req, res, err := tq.Wait(ctx, ticket) + is.NoErr(err) + req <- wantInt + gotFloat := <-res + is.Equal(wantFloat, gotFloat) + }() + + req, res, err := tq.Next(ctx) + is.NoErr(err) + + gotInt := <-req + is.Equal(wantInt, gotInt) + + res <- wantFloat + <-done +} + +func ExampleTicketQueue() { + ctx := context.Background() + + tq := NewTicketQueue[string, error]() + defer tq.Close() + + sentence := []string{ + "Each", "word", "will", "be", "sent", "to", "the", "collector", "in", + "a", "separate", "goroutine", "and", "even", "though", "they", "will", + "sleep", "for", "a", "random", "amount", "of", "time,", "all", "words", + "will", "be", "processed", "in", "the", "right", "order.", + } + + r := rand.New(rand.NewSource(time.Now().UnixMilli())) + var wg sync.WaitGroup + for _, word := range sentence { + t := tq.Take() + wg.Add(1) + go func(word string) { + defer wg.Done() + // sleep for a random amount of time to simulate work being done + time.Sleep(time.Millisecond * time.Duration(r.Intn(100))) + // try to cash in ticket + req, res, err := tq.Wait(ctx, t) + if err != nil { + panic(cerrors.Errorf("unexpected error: %w", err)) + } + req <- word // send word to collector + err = <-res // receive error back + if err != nil { + panic(cerrors.Errorf("unexpected error: %w", err)) + } + }(word) + } + + // collect all tickets + var builder strings.Builder + for { + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + + req, res, err := tq.Next(ctx) + if err != nil { + if err == context.DeadlineExceeded { + break + } + panic(cerrors.Errorf("unexpected error: %w", err)) + } + + word := <-req + _, 
err = builder.WriteRune(' ') + if err != nil { + res <- err + } + _, err = builder.WriteString(word) + if err != nil { + res <- err + } + close(res) + } + wg.Wait() + + fmt.Println(builder.String()) + + // Output: + // Each word will be sent to the collector in a separate goroutine and even though they will sleep for a random amount of time, all words will be processed in the right order. +} From 8bf3f66066559c8ca48d6fc7fba14147ba9275e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 14 Jun 2022 19:47:00 +0200 Subject: [PATCH 02/46] experiment with ordered semaphore --- pkg/foundation/semaphore/semaphore.go | 170 +++++++++++++++ .../semaphore/semaphore_bench_test.go | 90 ++++++++ pkg/foundation/semaphore/semaphore_test.go | 195 ++++++++++++++++++ .../ticketqueue/ticketqueue_test.go | 6 +- 4 files changed, 458 insertions(+), 3 deletions(-) create mode 100644 pkg/foundation/semaphore/semaphore.go create mode 100644 pkg/foundation/semaphore/semaphore_bench_test.go create mode 100644 pkg/foundation/semaphore/semaphore_test.go diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go new file mode 100644 index 000000000..ad74ff126 --- /dev/null +++ b/pkg/foundation/semaphore/semaphore.go @@ -0,0 +1,170 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package semaphore provides a weighted semaphore implementation. +package semaphore + +import ( + "container/list" + "context" + "sync" +) + +// NewWeighted creates a new weighted semaphore with the given +// maximum combined weight for concurrent access. +func NewWeighted(n int64) *Weighted { + w := &Weighted{size: n} + return w +} + +// Weighted provides a way to bound concurrent access to a resource. +// The callers can request access with a given weight. +type Weighted struct { + size int64 + cur int64 + mu sync.Mutex + waiters list.List +} + +type waiter struct { + acquired bool + released bool + n int64 + ready chan struct{} // Closed when semaphore acquired. +} + +type Ticket struct { + elem *list.Element +} + +func (s *Weighted) Enqueue(n int64) Ticket { + if n > s.size { + panic("semaphore: tried to enqueue more than size of semaphore") + } + + s.mu.Lock() + w := waiter{n: n, ready: make(chan struct{})} + e := s.waiters.PushBack(w) + s.mu.Unlock() + return Ticket{elem: e} +} + +// Acquire acquires the semaphore with a weight of n, blocking until resources +// are available or ctx is done. On success, returns nil. On failure, returns +// ctx.Err() and leaves the semaphore unchanged. +// +// If ctx is already done, Acquire may still succeed without blocking. +func (s *Weighted) Acquire(ctx context.Context, t Ticket) error { + w := t.elem.Value.(waiter) + + s.mu.Lock() + if s.waiters.Front() == t.elem && s.size-s.cur >= w.n { + s.cur += w.n + s.waiters.Remove(t.elem) + w.acquired = true + t.elem.Value = w + // If there are extra tokens left, notify other waiters. 
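+		// The waiter now at the front of the queue may need fewer tokens than
+		// are left, in which case it can be woken up right away.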
+ if s.size > s.cur { + s.notifyWaiters() + } + s.mu.Unlock() + return nil + } + if w.n > s.size { + // Don't make other Acquire calls block on one that's doomed to fail. + s.mu.Unlock() + <-ctx.Done() + return ctx.Err() + } + s.mu.Unlock() + + select { + case <-ctx.Done(): + err := ctx.Err() + s.mu.Lock() + select { + case <-w.ready: + // Acquired the semaphore after we were canceled. Rather than trying to + // fix up the queue, just pretend we didn't notice the cancelation. + err = nil + default: + isFront := s.waiters.Front() == t.elem + s.waiters.Remove(t.elem) + // If we're at the front and there are extra tokens left, notify other waiters. + if isFront && s.size > s.cur { + s.notifyWaiters() + } + } + s.mu.Unlock() + return err + + case <-w.ready: + return nil + } +} + +// Release releases the semaphore with a weight of n. +func (s *Weighted) Release(t Ticket) { + w := t.elem.Value.(waiter) + s.mu.Lock() + if !w.acquired { + s.mu.Unlock() + panic("semaphore: can't release ticket that was not acquired") + } + if w.released { + s.mu.Unlock() + panic("semaphore: ticket released twice") + } + + s.cur -= t.elem.Value.(waiter).n + w.released = true + t.elem.Value = w + if s.cur < 0 { + s.mu.Unlock() + panic("semaphore: released more than held") + } + s.notifyWaiters() + s.mu.Unlock() +} + +func (s *Weighted) notifyWaiters() { + for { + next := s.waiters.Front() + if next == nil { + break // No more waiters blocked. + } + + w := next.Value.(waiter) + if s.size-s.cur < w.n { + // Not enough tokens for the next waiter. We could keep going (to try to + // find a waiter with a smaller request), but under load that could cause + // starvation for large requests; instead, we leave all remaining waiters + // blocked. + // + // Consider a semaphore used as a read-write lock, with N tokens, N + // readers, and one writer. Each reader can Acquire(1) to obtain a read + // lock. The writer can Acquire(N) to obtain a write lock, excluding all + // of the readers. If we allow the readers to jump ahead in the queue, + // the writer will starve — there is always one token available for every + // reader. + break + } + + w.acquired = true + next.Value = w + s.cur += w.n + s.waiters.Remove(next) + close(w.ready) + } +} diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go new file mode 100644 index 000000000..0dc1857ad --- /dev/null +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -0,0 +1,90 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package semaphore_test + +import ( + "container/list" + "context" + "fmt" + "testing" + + "github.com/conduitio/conduit/pkg/foundation/semaphore" +) + +// weighted is an interface matching a subset of *Weighted. It allows +// alternate implementations for testing and benchmarking. 
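+// Implementations are expected to follow the same Enqueue/Acquire/Release
+// ticket flow as *Weighted.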
+type weighted interface { + Enqueue(int64) semaphore.Ticket + Acquire(context.Context, semaphore.Ticket) error + Release(semaphore.Ticket) +} + +// acquireN calls Acquire(size) on sem N times and then calls Release(size) N times. +func acquireN(b *testing.B, sem weighted, size int64, N int) { + b.ResetTimer() + tickets := list.New() + for i := 0; i < b.N; i++ { + tickets.Init() + for j := 0; j < N; j++ { + ticket := sem.Enqueue(size) + tickets.PushBack(ticket) + sem.Acquire(context.Background(), ticket) + } + ticket := tickets.Front() + for ticket != nil { + sem.Release(ticket.Value.(semaphore.Ticket)) + ticket = ticket.Next() + } + } +} + +func BenchmarkNewSeq(b *testing.B) { + for _, cap := range []int64{1, 128} { + b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = semaphore.NewWeighted(cap) + } + }) + } +} + +func BenchmarkAcquireSeq(b *testing.B) { + for _, c := range []struct { + cap, size int64 + N int + }{ + {1, 1, 1}, + {2, 1, 1}, + {16, 1, 1}, + {128, 1, 1}, + {2, 2, 1}, + {16, 2, 8}, + {128, 2, 64}, + {2, 1, 2}, + {16, 8, 2}, + {128, 64, 2}, + } { + for _, w := range []struct { + name string + w weighted + }{ + {"Weighted", semaphore.NewWeighted(c.cap)}, + } { + b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) { + acquireN(b, w.w, c.size, c.N) + }) + } + } +} diff --git a/pkg/foundation/semaphore/semaphore_test.go b/pkg/foundation/semaphore/semaphore_test.go new file mode 100644 index 000000000..89a2b5428 --- /dev/null +++ b/pkg/foundation/semaphore/semaphore_test.go @@ -0,0 +1,195 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package semaphore_test + +import ( + "context" + "math/rand" + "runtime" + "sync" + "testing" + "time" + + "github.com/conduitio/conduit/pkg/foundation/semaphore" +) + +const maxSleep = 1 * time.Millisecond + +func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) { + for i := 0; i < loops; i++ { + tkn := sem.Enqueue(n) + err := sem.Acquire(context.Background(), tkn) + if err != nil { + panic(err) + } + time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond) + sem.Release(tkn) + } +} + +func TestWeighted(t *testing.T) { + t.Parallel() + + n := runtime.GOMAXPROCS(0) + loops := 10000 / n + sem := semaphore.NewWeighted(int64(n)) + + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + i := i + go func() { + defer wg.Done() + HammerWeighted(sem, int64(i), loops) + }() + } + wg.Wait() +} + +func TestWeightedPanicReleaseUnacquired(t *testing.T) { + t.Parallel() + + defer func() { + if recover() == nil { + t.Fatal("release of an unacquired weighted semaphore did not panic") + } + }() + w := semaphore.NewWeighted(1) + tkn := w.Enqueue(1) + w.Release(tkn) +} + +func TestWeightedPanicEnqueueTooBig(t *testing.T) { + t.Parallel() + + defer func() { + if recover() == nil { + t.Fatal("enqueue of size bigger than weighted semaphore did not panic") + } + }() + const n = 5 + sem := semaphore.NewWeighted(n) + sem.Enqueue(n + 1) +} + +func TestWeightedAcquire(t *testing.T) { + t.Parallel() + + ctx := context.Background() + sem := semaphore.NewWeighted(2) + tryAcquire := func(n int64) bool { + ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + + tkn := sem.Enqueue(1) + return sem.Acquire(ctx, tkn) == nil + } + + tries := []bool{} + tkn := sem.Enqueue(1) + sem.Acquire(ctx, tkn) + tries = append(tries, tryAcquire(1)) + tries = append(tries, tryAcquire(1)) + + sem.Release(tkn) + + tkn = sem.Enqueue(1) + sem.Acquire(ctx, tkn) + tries = append(tries, tryAcquire(1)) + + want := []bool{true, false, false} + for i := range tries { + if tries[i] != want[i] { + t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i]) + } + } +} + +// TestLargeAcquireDoesntStarve times out if a large call to Acquire starves. +// Merely returning from the test function indicates success. +func TestLargeAcquireDoesntStarve(t *testing.T) { + t.Parallel() + + ctx := context.Background() + n := int64(runtime.GOMAXPROCS(0)) + sem := semaphore.NewWeighted(n) + running := true + + var wg sync.WaitGroup + wg.Add(int(n)) + for i := n; i > 0; i-- { + tkn := sem.Enqueue(1) + sem.Acquire(ctx, tkn) + go func() { + defer func() { + sem.Release(tkn) + wg.Done() + }() + for running { + time.Sleep(1 * time.Millisecond) + sem.Release(tkn) + tkn = sem.Enqueue(1) + sem.Acquire(ctx, tkn) + } + }() + } + + tkn := sem.Enqueue(n) + sem.Acquire(ctx, tkn) + running = false + sem.Release(tkn) + wg.Wait() +} + +// translated from https://github.com/zhiqiangxu/util/blob/master/mutex/crwmutex_test.go#L43 +func TestAllocCancelDoesntStarve(t *testing.T) { + sem := semaphore.NewWeighted(10) + + // Block off a portion of the semaphore so that Acquire(_, 10) can eventually succeed. + tkn := sem.Enqueue(1) + sem.Acquire(context.Background(), tkn) + + // In the background, Acquire(_, 10). + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + tkn := sem.Enqueue(10) + sem.Acquire(ctx, tkn) + }() + + // Wait until the Acquire(_, 10) call blocks. 
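+	// Probe the semaphore with weight-1 tickets under a short timeout: each
+	// probe is enqueued behind the Acquire(_, 10) waiter, so once that waiter
+	// reaches the front of the queue a probe can no longer succeed and its
+	// Acquire call returns a timeout error.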
+ for { + ctx, cancel := context.WithTimeout(ctx, time.Millisecond) + tkn := sem.Enqueue(1) + err := sem.Acquire(ctx, tkn) + cancel() + if err != nil { + break + } + sem.Release(tkn) + runtime.Gosched() + } + + // Now try to grab a read lock, and simultaneously unblock the Acquire(_, 10) call. + // Both Acquire calls should unblock and return, in either order. + go cancel() + + tkn = sem.Enqueue(1) + err := sem.Acquire(context.Background(), tkn) + if err != nil { + t.Fatalf("Acquire(_, 1) failed unexpectedly: %v", err) + } + sem.Release(tkn) +} diff --git a/pkg/foundation/ticketqueue/ticketqueue_test.go b/pkg/foundation/ticketqueue/ticketqueue_test.go index 76f350f46..18a5adedb 100644 --- a/pkg/foundation/ticketqueue/ticketqueue_test.go +++ b/pkg/foundation/ticketqueue/ticketqueue_test.go @@ -215,10 +215,10 @@ func ExampleTicketQueue() { for _, word := range sentence { t := tq.Take() wg.Add(1) - go func(word string) { + go func(word string, delay time.Duration) { defer wg.Done() // sleep for a random amount of time to simulate work being done - time.Sleep(time.Millisecond * time.Duration(r.Intn(100))) + time.Sleep(delay) // try to cash in ticket req, res, err := tq.Wait(ctx, t) if err != nil { @@ -229,7 +229,7 @@ func ExampleTicketQueue() { if err != nil { panic(cerrors.Errorf("unexpected error: %w", err)) } - }(word) + }(word, time.Millisecond*time.Duration(r.Intn(100))) } // collect all tickets From 424d88956e53111dd512a9ef8e72b325cf360147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 14 Jun 2022 20:01:01 +0200 Subject: [PATCH 03/46] ticketqueue benchmarks --- .../semaphore/semaphore_bench_test.go | 4 +- .../ticketqueue/ticketqueue_bench_test.go | 60 +++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 pkg/foundation/ticketqueue/ticketqueue_bench_test.go diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go index 0dc1857ad..565b5ac94 100644 --- a/pkg/foundation/semaphore/semaphore_bench_test.go +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -50,7 +50,7 @@ func acquireN(b *testing.B, sem weighted, size int64, N int) { } } -func BenchmarkNewSeq(b *testing.B) { +func BenchmarkNewSem(b *testing.B) { for _, cap := range []int64{1, 128} { b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -60,7 +60,7 @@ func BenchmarkNewSeq(b *testing.B) { } } -func BenchmarkAcquireSeq(b *testing.B) { +func BenchmarkAcquireSem(b *testing.B) { for _, c := range []struct { cap, size int64 N int diff --git a/pkg/foundation/ticketqueue/ticketqueue_bench_test.go b/pkg/foundation/ticketqueue/ticketqueue_bench_test.go new file mode 100644 index 000000000..844136c15 --- /dev/null +++ b/pkg/foundation/ticketqueue/ticketqueue_bench_test.go @@ -0,0 +1,60 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ticketqueue + +import ( + "context" + "fmt" + "testing" + + "github.com/gammazero/deque" +) + +func BenchmarkNewTicketQueue(b *testing.B) { + for i := 0; i < b.N; i++ { + tq := NewTicketQueue[int, int]() + defer tq.Close() + } + b.StopTimer() // don't measure Close +} + +func BenchmarkTicketQueueTake(b *testing.B) { + for _, N := range []int{1, 2, 8, 64, 128} { + b.Run(fmt.Sprintf("TicketQueue-%d", N), func(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + tq := NewTicketQueue[int, int]() + go func() { + for { + _, _, err := tq.Next(ctx) + if err == context.Canceled { + return + } + } + }() + + for i := 0; i < b.N; i++ { + tickets := deque.Deque[Ticket[int, int]]{} + for j := 0; j < N; j++ { + ticket := tq.Take() + tickets.PushBack(ticket) + } + for tickets.Len() > 0 { + tq.Wait(ctx, tickets.PopFront()) + } + } + cancel() + }) + } +} From 8a852db3ca4c92f66173c2d13544ae69d9dd2a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Fri, 17 Jun 2022 20:45:01 +0200 Subject: [PATCH 04/46] reduce allocations --- pkg/foundation/semaphore/semaphore.go | 143 +++++++++--------- .../semaphore/semaphore_bench_test.go | 14 +- pkg/foundation/semaphore/semaphore_test.go | 140 ++++++++--------- 3 files changed, 138 insertions(+), 159 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go index ad74ff126..e1f06af5d 100644 --- a/pkg/foundation/semaphore/semaphore.go +++ b/pkg/foundation/semaphore/semaphore.go @@ -16,9 +16,9 @@ package semaphore import ( - "container/list" - "context" "sync" + + "github.com/conduitio/conduit/pkg/foundation/cerrors" ) // NewWeighted creates a new weighted semaphore with the given @@ -31,21 +31,28 @@ func NewWeighted(n int64) *Weighted { // Weighted provides a way to bound concurrent access to a resource. // The callers can request access with a given weight. type Weighted struct { - size int64 - cur int64 - mu sync.Mutex - waiters list.List + size int64 + cur int64 + released int + mu sync.Mutex + + waiters []waiter + front int + batch int64 } type waiter struct { - acquired bool + index int + n int64 + ready chan struct{} // Closed when semaphore acquired. + released bool - n int64 - ready chan struct{} // Closed when semaphore acquired. + acquired bool } type Ticket struct { - elem *list.Element + index int + batch int64 } func (s *Weighted) Enqueue(n int64) Ticket { @@ -54,10 +61,16 @@ func (s *Weighted) Enqueue(n int64) Ticket { } s.mu.Lock() - w := waiter{n: n, ready: make(chan struct{})} - e := s.waiters.PushBack(w) - s.mu.Unlock() - return Ticket{elem: e} + defer s.mu.Unlock() + + index := len(s.waiters) + w := waiter{index: index, n: n, ready: make(chan struct{})} + s.waiters = append(s.waiters, w) + + return Ticket{ + index: index, + batch: s.batch, + } } // Acquire acquires the semaphore with a weight of n, blocking until resources @@ -65,15 +78,24 @@ func (s *Weighted) Enqueue(n int64) Ticket { // ctx.Err() and leaves the semaphore unchanged. // // If ctx is already done, Acquire may still succeed without blocking. 
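+//
+// A ticket moves through the semaphore in three steps (an illustrative
+// sketch; error handling elided):
+//
+//	tkn := s.Enqueue(n)  // reserve a place in line
+//	_ = s.Acquire(tkn)   // block until tkn is at the front and n tokens are free
+//	defer s.Release(tkn) // return the tokens and notify the next waiter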
-func (s *Weighted) Acquire(ctx context.Context, t Ticket) error { - w := t.elem.Value.(waiter) - +func (s *Weighted) Acquire(t Ticket) error { s.mu.Lock() - if s.waiters.Front() == t.elem && s.size-s.cur >= w.n { + if s.batch != t.batch { + s.mu.Unlock() + return cerrors.Errorf("semaphore: invalid batch") + } + + w := s.waiters[t.index] + if w.acquired { + return cerrors.New("semaphore: can't acquire ticket that was already acquired") + } + + w.acquired = true // mark that Acquire was already called for this Ticket + s.waiters[t.index] = w + + if s.front == t.index && s.size-s.cur >= w.n { s.cur += w.n - s.waiters.Remove(t.elem) - w.acquired = true - t.elem.Value = w + s.front++ // If there are extra tokens left, notify other waiters. if s.size > s.cur { s.notifyWaiters() @@ -81,71 +103,42 @@ func (s *Weighted) Acquire(ctx context.Context, t Ticket) error { s.mu.Unlock() return nil } - if w.n > s.size { - // Don't make other Acquire calls block on one that's doomed to fail. - s.mu.Unlock() - <-ctx.Done() - return ctx.Err() - } s.mu.Unlock() - select { - case <-ctx.Done(): - err := ctx.Err() - s.mu.Lock() - select { - case <-w.ready: - // Acquired the semaphore after we were canceled. Rather than trying to - // fix up the queue, just pretend we didn't notice the cancelation. - err = nil - default: - isFront := s.waiters.Front() == t.elem - s.waiters.Remove(t.elem) - // If we're at the front and there are extra tokens left, notify other waiters. - if isFront && s.size > s.cur { - s.notifyWaiters() - } - } - s.mu.Unlock() - return err - - case <-w.ready: - return nil - } + <-w.ready + return nil } // Release releases the semaphore with a weight of n. -func (s *Weighted) Release(t Ticket) { - w := t.elem.Value.(waiter) +func (s *Weighted) Release(t Ticket) error { s.mu.Lock() + defer s.mu.Unlock() + + if s.batch != t.batch { + return cerrors.Errorf("semaphore: invalid batch") + } + w := s.waiters[t.index] if !w.acquired { - s.mu.Unlock() - panic("semaphore: can't release ticket that was not acquired") + return cerrors.New("semaphore: can't release ticket that was not acquired") } if w.released { - s.mu.Unlock() - panic("semaphore: ticket released twice") + return cerrors.New("semaphore: ticket already released") } - s.cur -= t.elem.Value.(waiter).n + s.cur -= w.n w.released = true - t.elem.Value = w - if s.cur < 0 { - s.mu.Unlock() - panic("semaphore: released more than held") - } + s.waiters[t.index] = w + s.released++ s.notifyWaiters() - s.mu.Unlock() + if s.released == len(s.waiters) { + s.increaseBatch() + } + return nil } func (s *Weighted) notifyWaiters() { - for { - next := s.waiters.Front() - if next == nil { - break // No more waiters blocked. - } - - w := next.Value.(waiter) + for len(s.waiters) > s.front { + w := s.waiters[s.front] if s.size-s.cur < w.n { // Not enough tokens for the next waiter. 
We could keep going (to try to // find a waiter with a smaller request), but under load that could cause @@ -161,10 +154,14 @@ func (s *Weighted) notifyWaiters() { break } - w.acquired = true - next.Value = w s.cur += w.n - s.waiters.Remove(next) + s.front++ close(w.ready) } } + +func (s *Weighted) increaseBatch() { + s.waiters = s.waiters[:0] + s.batch += 1 + s.front = 0 +} diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go index 565b5ac94..73dcfb686 100644 --- a/pkg/foundation/semaphore/semaphore_bench_test.go +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -16,7 +16,6 @@ package semaphore_test import ( "container/list" - "context" "fmt" "testing" @@ -27,8 +26,8 @@ import ( // alternate implementations for testing and benchmarking. type weighted interface { Enqueue(int64) semaphore.Ticket - Acquire(context.Context, semaphore.Ticket) error - Release(semaphore.Ticket) + Acquire(semaphore.Ticket) error + Release(semaphore.Ticket) error } // acquireN calls Acquire(size) on sem N times and then calls Release(size) N times. @@ -38,14 +37,7 @@ func acquireN(b *testing.B, sem weighted, size int64, N int) { for i := 0; i < b.N; i++ { tickets.Init() for j := 0; j < N; j++ { - ticket := sem.Enqueue(size) - tickets.PushBack(ticket) - sem.Acquire(context.Background(), ticket) - } - ticket := tickets.Front() - for ticket != nil { - sem.Release(ticket.Value.(semaphore.Ticket)) - ticket = ticket.Next() + _ = sem.Enqueue(size) } } } diff --git a/pkg/foundation/semaphore/semaphore_test.go b/pkg/foundation/semaphore/semaphore_test.go index 89a2b5428..9845f58fc 100644 --- a/pkg/foundation/semaphore/semaphore_test.go +++ b/pkg/foundation/semaphore/semaphore_test.go @@ -15,7 +15,6 @@ package semaphore_test import ( - "context" "math/rand" "runtime" "sync" @@ -30,7 +29,7 @@ const maxSleep = 1 * time.Millisecond func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) { for i := 0; i < loops; i++ { tkn := sem.Enqueue(n) - err := sem.Acquire(context.Background(), tkn) + err := sem.Acquire(tkn) if err != nil { panic(err) } @@ -58,17 +57,48 @@ func TestWeighted(t *testing.T) { wg.Wait() } -func TestWeightedPanicReleaseUnacquired(t *testing.T) { +func TestWeightedReleaseUnacquired(t *testing.T) { t.Parallel() - defer func() { - if recover() == nil { - t.Fatal("release of an unacquired weighted semaphore did not panic") - } - }() w := semaphore.NewWeighted(1) tkn := w.Enqueue(1) - w.Release(tkn) + err := w.Release(tkn) + if err == nil { + t.Errorf("release of an unacquired ticket did not return an error") + } +} + +func TestWeightedReleaseTwice(t *testing.T) { + t.Parallel() + + w := semaphore.NewWeighted(1) + tkn := w.Enqueue(1) + w.Acquire(tkn) + err := w.Release(tkn) + if err != nil { + t.Errorf("release of an acquired ticket errored out: %v", err) + } + + err = w.Release(tkn) + if err == nil { + t.Errorf("release of an already released ticket did not return an error") + } +} + +func TestWeightedAcquireTwice(t *testing.T) { + t.Parallel() + + w := semaphore.NewWeighted(1) + tkn := w.Enqueue(1) + err := w.Acquire(tkn) + if err != nil { + t.Errorf("acquire of a ticket errored out: %v", err) + } + + err = w.Acquire(tkn) + if err == nil { + t.Errorf("acquire of an already acquired ticket did not return an error") + } } func TestWeightedPanicEnqueueTooBig(t *testing.T) { @@ -87,33 +117,35 @@ func TestWeightedPanicEnqueueTooBig(t *testing.T) { func TestWeightedAcquire(t *testing.T) { t.Parallel() - ctx := context.Background() sem := 
semaphore.NewWeighted(2) - tryAcquire := func(n int64) bool { - ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) - defer cancel() - tkn := sem.Enqueue(1) - return sem.Acquire(ctx, tkn) == nil - } + tkn1 := sem.Enqueue(1) + sem.Acquire(tkn1) - tries := []bool{} - tkn := sem.Enqueue(1) - sem.Acquire(ctx, tkn) - tries = append(tries, tryAcquire(1)) - tries = append(tries, tryAcquire(1)) + tkn2 := sem.Enqueue(1) + sem.Acquire(tkn2) - sem.Release(tkn) + tkn3done := make(chan struct{}) + go func() { + defer close(tkn3done) + tkn3 := sem.Enqueue(1) + sem.Acquire(tkn3) + }() - tkn = sem.Enqueue(1) - sem.Acquire(ctx, tkn) - tries = append(tries, tryAcquire(1)) + select { + case <-tkn3done: + t.Errorf("tkn3done closed prematurely") + case <-time.After(time.Millisecond * 10): + // tkn3 Acquire is blocking as expected + } - want := []bool{true, false, false} - for i := range tries { - if tries[i] != want[i] { - t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i]) - } + sem.Release(tkn1) + + select { + case <-tkn3done: + // tkn3 successfully acquired the semaphore + case <-time.After(time.Millisecond * 10): + t.Errorf("tkn3done didn't get closed") } } @@ -122,7 +154,6 @@ func TestWeightedAcquire(t *testing.T) { func TestLargeAcquireDoesntStarve(t *testing.T) { t.Parallel() - ctx := context.Background() n := int64(runtime.GOMAXPROCS(0)) sem := semaphore.NewWeighted(n) running := true @@ -131,7 +162,7 @@ func TestLargeAcquireDoesntStarve(t *testing.T) { wg.Add(int(n)) for i := n; i > 0; i-- { tkn := sem.Enqueue(1) - sem.Acquire(ctx, tkn) + sem.Acquire(tkn) go func() { defer func() { sem.Release(tkn) @@ -141,55 +172,14 @@ func TestLargeAcquireDoesntStarve(t *testing.T) { time.Sleep(1 * time.Millisecond) sem.Release(tkn) tkn = sem.Enqueue(1) - sem.Acquire(ctx, tkn) + sem.Acquire(tkn) } }() } tkn := sem.Enqueue(n) - sem.Acquire(ctx, tkn) + sem.Acquire(tkn) running = false sem.Release(tkn) wg.Wait() } - -// translated from https://github.com/zhiqiangxu/util/blob/master/mutex/crwmutex_test.go#L43 -func TestAllocCancelDoesntStarve(t *testing.T) { - sem := semaphore.NewWeighted(10) - - // Block off a portion of the semaphore so that Acquire(_, 10) can eventually succeed. - tkn := sem.Enqueue(1) - sem.Acquire(context.Background(), tkn) - - // In the background, Acquire(_, 10). - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - tkn := sem.Enqueue(10) - sem.Acquire(ctx, tkn) - }() - - // Wait until the Acquire(_, 10) call blocks. - for { - ctx, cancel := context.WithTimeout(ctx, time.Millisecond) - tkn := sem.Enqueue(1) - err := sem.Acquire(ctx, tkn) - cancel() - if err != nil { - break - } - sem.Release(tkn) - runtime.Gosched() - } - - // Now try to grab a read lock, and simultaneously unblock the Acquire(_, 10) call. - // Both Acquire calls should unblock and return, in either order. 
- go cancel() - - tkn = sem.Enqueue(1) - err := sem.Acquire(context.Background(), tkn) - if err != nil { - t.Fatalf("Acquire(_, 1) failed unexpectedly: %v", err) - } - sem.Release(tkn) -} From ec7249ebfc0bd247fa4de9c98d00c7cc0f968ffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 21 Jun 2022 17:24:27 +0200 Subject: [PATCH 05/46] remove ticketqueue (semaphore implementation is more performant) --- pkg/foundation/ticketqueue/ticketqueue.go | 170 ----------- .../ticketqueue/ticketqueue_bench_test.go | 60 ---- .../ticketqueue/ticketqueue_test.go | 266 ------------------ 3 files changed, 496 deletions(-) delete mode 100644 pkg/foundation/ticketqueue/ticketqueue.go delete mode 100644 pkg/foundation/ticketqueue/ticketqueue_bench_test.go delete mode 100644 pkg/foundation/ticketqueue/ticketqueue_test.go diff --git a/pkg/foundation/ticketqueue/ticketqueue.go b/pkg/foundation/ticketqueue/ticketqueue.go deleted file mode 100644 index d738ce607..000000000 --- a/pkg/foundation/ticketqueue/ticketqueue.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright © 2022 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ticketqueue - -import ( - "context" - - "github.com/conduitio/conduit/pkg/foundation/cerrors" - "github.com/gammazero/deque" -) - -// TicketQueue dispenses tickets and keeps track of their order. Tickets can be -// "cashed-in" for channels that let the caller communicate with a worker that -// is supposed to handle the ticket. TicketQueue ensures that tickets are -// cashed-in in the exact same order as they were dispensed. -// -// Essentially TicketQueue simulates a "take a number" system where the number -// on the ticket is monotonically increasing with each dispensed ticket. When -// the monitor displays the number on the ticket, the person holding the ticket -// can approach the counter. -// -// TicketQueue contains an unbounded buffer for tickets. A goroutine is pushing -// tickets to workers calling Next. To stop this goroutine the TicketQueue needs -// to be closed and all tickets need to be drained through Next until it returns -// an error. -type TicketQueue[REQ, RES any] struct { - // in is the channel where incoming tickets are sent into (see Take) - in chan Ticket[REQ, RES] - // out is the channel where outgoing tickets are sent into (see Next) - out chan Ticket[REQ, RES] -} - -// NewTicketQueue returns an initialized TicketQueue. -func NewTicketQueue[REQ, RES any]() *TicketQueue[REQ, RES] { - tq := &TicketQueue[REQ, RES]{ - in: make(chan Ticket[REQ, RES]), - out: make(chan Ticket[REQ, RES]), - } - tq.run() - return tq -} - -// Ticket is dispensed by TicketQueue. Once TicketQueue.Wait is called with a -// Ticket it should be discarded. -type Ticket[REQ, RES any] struct { - ctrl chan struct{} - req chan REQ - res chan RES -} - -// run launches a goroutine that fetches tickets from the channel in and buffers -// them in an unbounded queue. It also pushes tickets from the queue into the -// channel out. 
-func (tq *TicketQueue[REQ, RES]) run() { - in := tq.in - - // Deque is used as a normal queue and holds references to all open tickets - var q deque.Deque[Ticket[REQ, RES]] - outOrNil := func() chan Ticket[REQ, RES] { - if q.Len() == 0 { - return nil - } - return tq.out - } - nextTicket := func() Ticket[REQ, RES] { - if q.Len() == 0 { - return Ticket[REQ, RES]{} - } - return q.Front() - } - - go func() { - defer close(tq.out) - for q.Len() > 0 || in != nil { - select { - case v, ok := <-in: - if !ok { - in = nil - continue - } - q.PushBack(v) - case outOrNil() <- nextTicket(): - q.PopFront() // remove ticket from queue - } - } - }() -} - -// Take creates a ticket. The ticket can be used to call Wait. If TicketQueue -// is already closed, the call panics. -func (tq *TicketQueue[REQ, RES]) Take() Ticket[REQ, RES] { - t := Ticket[REQ, RES]{ - ctrl: make(chan struct{}), - req: make(chan REQ), - res: make(chan RES), - } - tq.in <- t - return t -} - -// Wait will block until all tickets before this ticket were already processed. -// Essentially this method means the caller wants to enqueue and wait for their -// turn. The function returns two channels that can be used to communicate with -// the processor of the ticket. The caller determines what messages are sent -// through those channels (if any). After Wait returns the ticket should be -// discarded. -// -// If ctx gets cancelled before the ticket is redeemed, the function returns the -// context error. If Wait is called a second time with the same ticket, the call -// returns an error. -func (tq *TicketQueue[REQ, RES]) Wait(ctx context.Context, t Ticket[REQ, RES]) (chan<- REQ, <-chan RES, error) { - select { - case <-ctx.Done(): - return nil, nil, ctx.Err() - case _, ok := <-t.ctrl: - if !ok { - return nil, nil, cerrors.New("ticket already used") - } - } - return t.req, t.res, nil -} - -// Next can be used to fetch the channels to communicate with the next ticket -// holder in line. If there is no next ticket holder or if the next ticket -// holder did not call Wait, the call will block. -// -// If ctx gets cancelled before the next ticket holder is ready, the function -// returns the context error. If TicketQueue is closed and there are no more -// open tickets, the call returns an error. -func (tq *TicketQueue[REQ, RES]) Next(ctx context.Context) (<-chan REQ, chan<- RES, error) { - var t Ticket[REQ, RES] - var ok bool - - select { - case <-ctx.Done(): - return nil, nil, ctx.Err() - case t, ok = <-tq.out: - if !ok { - return nil, nil, cerrors.New("TicketQueue is closed") - } - } - - select { - case <-ctx.Done(): - // BUG: the ticket is lost at this point - return nil, nil, ctx.Err() - case t.ctrl <- struct{}{}: // signal that Next is ready to proceed - close(t.ctrl) // ticket is used - } - - return t.req, t.res, nil -} - -// Close the ticket queue, no more new tickets can be dispensed after this. -// Calls to Wait and Next are still allowed until all open tickets are redeemed. -func (tq *TicketQueue[REQ, RES]) Close() { - close(tq.in) -} diff --git a/pkg/foundation/ticketqueue/ticketqueue_bench_test.go b/pkg/foundation/ticketqueue/ticketqueue_bench_test.go deleted file mode 100644 index 844136c15..000000000 --- a/pkg/foundation/ticketqueue/ticketqueue_bench_test.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright © 2022 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ticketqueue - -import ( - "context" - "fmt" - "testing" - - "github.com/gammazero/deque" -) - -func BenchmarkNewTicketQueue(b *testing.B) { - for i := 0; i < b.N; i++ { - tq := NewTicketQueue[int, int]() - defer tq.Close() - } - b.StopTimer() // don't measure Close -} - -func BenchmarkTicketQueueTake(b *testing.B) { - for _, N := range []int{1, 2, 8, 64, 128} { - b.Run(fmt.Sprintf("TicketQueue-%d", N), func(b *testing.B) { - ctx, cancel := context.WithCancel(context.Background()) - tq := NewTicketQueue[int, int]() - go func() { - for { - _, _, err := tq.Next(ctx) - if err == context.Canceled { - return - } - } - }() - - for i := 0; i < b.N; i++ { - tickets := deque.Deque[Ticket[int, int]]{} - for j := 0; j < N; j++ { - ticket := tq.Take() - tickets.PushBack(ticket) - } - for tickets.Len() > 0 { - tq.Wait(ctx, tickets.PopFront()) - } - } - cancel() - }) - } -} diff --git a/pkg/foundation/ticketqueue/ticketqueue_test.go b/pkg/foundation/ticketqueue/ticketqueue_test.go deleted file mode 100644 index 18a5adedb..000000000 --- a/pkg/foundation/ticketqueue/ticketqueue_test.go +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright © 2022 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ticketqueue - -import ( - "context" - "fmt" - "math/rand" - "strings" - "sync" - "testing" - "time" - - "github.com/conduitio/conduit/pkg/foundation/cerrors" - "github.com/matryer/is" -) - -func TestTicketQueue_Next_ContextCanceled(t *testing.T) { - is := is.New(t) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) - defer cancel() - - tq := NewTicketQueue[int, float64]() - defer tq.Close() - - req, res, err := tq.Next(ctx) - is.Equal(req, nil) - is.Equal(res, nil) - is.Equal(err, context.DeadlineExceeded) -} - -func TestTicketQueue_Next_Closed(t *testing.T) { - is := is.New(t) - ctx := context.Background() - - tq := NewTicketQueue[int, float64]() - tq.Close() // close ticket queue - - req, res, err := tq.Next(ctx) - is.Equal(req, nil) - is.Equal(res, nil) - is.True(err != nil) -} - -func TestTicketQueue_Take_Closed(t *testing.T) { - is := is.New(t) - - tq := NewTicketQueue[int, float64]() - tq.Close() // close ticket queue, taking a ticket after this is not permitted - - defer func() { - is.True(recover() != nil) // expected Take to panic - }() - - tq.Take() -} - -func TestTicketQueue_Wait_ContextCanceled(t *testing.T) { - is := is.New(t) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) - defer cancel() - - tq := NewTicketQueue[int, float64]() - defer tq.Close() - - ticket := tq.Take() - req, res, err := tq.Wait(ctx, ticket) - is.Equal(req, nil) - is.Equal(res, nil) - is.Equal(err, context.DeadlineExceeded) -} - -func TestTicketQueue_Wait_ReuseTicket(t *testing.T) { - is := is.New(t) - ctx := context.Background() - - tq := NewTicketQueue[int, float64]() - defer tq.Close() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) - defer cancel() - _, _, err := tq.Next(ctx) - is.NoErr(err) - _, _, err = tq.Next(ctx) - is.Equal(err, context.DeadlineExceeded) - }() - - ticket := tq.Take() - _, _, err := tq.Wait(ctx, ticket) - is.NoErr(err) - - _, _, err = tq.Wait(ctx, ticket) - is.True(err != nil) // expected error for ticket that was already cashed-in - wg.Wait() -} - -func TestTicketQueue_Next_NoTicketWaiting(t *testing.T) { - is := is.New(t) - ctx := context.Background() - - tq := NewTicketQueue[int, float64]() - defer tq.Close() - - tq.Take() // take ticket, but don't cash it in - - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) - defer cancel() - _, _, err := tq.Next(ctx) - is.Equal(err, context.DeadlineExceeded) -} - -func TestTicketQueue_Take_Buffer(t *testing.T) { - is := is.New(t) - ctx := context.Background() - - tq := NewTicketQueue[int, float64]() - defer tq.Close() - - // TicketQueue supports an unbounded amount of tickets, keep taking tickets - // for one second and take as many tickets as possible - testDuration := time.Second - - var wg sync.WaitGroup - var numTickets int - start := time.Now() - for time.Since(start) < testDuration { - numTickets += 1 - ticket := tq.Take() - go func() { - defer wg.Done() - _, _, err := tq.Wait(ctx, ticket) - is.NoErr(err) - }() - } - wg.Add(numTickets) - t.Logf("took %d tickets in %s", numTickets, testDuration) - - for i := 0; i < numTickets; i++ { - _, _, err := tq.Next(ctx) - is.NoErr(err) - } - - wg.Wait() // wait for all ticket goroutines to finish - - // try fetching next in line, but there is none - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) - defer cancel() - _, _, err := tq.Next(ctx) - is.Equal(err, context.DeadlineExceeded) 
-} - -func TestTicketQueue_HandOff(t *testing.T) { - is := is.New(t) - ctx := context.Background() - - tq := NewTicketQueue[int, float64]() - defer tq.Close() - - wantInt := 123 - wantFloat := 1.23 - - done := make(chan struct{}) - go func() { - defer close(done) - ticket := tq.Take() - req, res, err := tq.Wait(ctx, ticket) - is.NoErr(err) - req <- wantInt - gotFloat := <-res - is.Equal(wantFloat, gotFloat) - }() - - req, res, err := tq.Next(ctx) - is.NoErr(err) - - gotInt := <-req - is.Equal(wantInt, gotInt) - - res <- wantFloat - <-done -} - -func ExampleTicketQueue() { - ctx := context.Background() - - tq := NewTicketQueue[string, error]() - defer tq.Close() - - sentence := []string{ - "Each", "word", "will", "be", "sent", "to", "the", "collector", "in", - "a", "separate", "goroutine", "and", "even", "though", "they", "will", - "sleep", "for", "a", "random", "amount", "of", "time,", "all", "words", - "will", "be", "processed", "in", "the", "right", "order.", - } - - r := rand.New(rand.NewSource(time.Now().UnixMilli())) - var wg sync.WaitGroup - for _, word := range sentence { - t := tq.Take() - wg.Add(1) - go func(word string, delay time.Duration) { - defer wg.Done() - // sleep for a random amount of time to simulate work being done - time.Sleep(delay) - // try to cash in ticket - req, res, err := tq.Wait(ctx, t) - if err != nil { - panic(cerrors.Errorf("unexpected error: %w", err)) - } - req <- word // send word to collector - err = <-res // receive error back - if err != nil { - panic(cerrors.Errorf("unexpected error: %w", err)) - } - }(word, time.Millisecond*time.Duration(r.Intn(100))) - } - - // collect all tickets - var builder strings.Builder - for { - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*200) - defer cancel() - - req, res, err := tq.Next(ctx) - if err != nil { - if err == context.DeadlineExceeded { - break - } - panic(cerrors.Errorf("unexpected error: %w", err)) - } - - word := <-req - _, err = builder.WriteRune(' ') - if err != nil { - res <- err - } - _, err = builder.WriteString(word) - if err != nil { - res <- err - } - close(res) - } - wg.Wait() - - fmt.Println(builder.String()) - - // Output: - // Each word will be sent to the collector in a separate goroutine and even though they will sleep for a random amount of time, all words will be processed in the right order. -} From b288e216e7a1b6154a86822ceb4cc52c84adaf54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 21 Jun 2022 17:24:39 +0200 Subject: [PATCH 06/46] optimize semaphore for our use case --- pkg/foundation/semaphore/semaphore.go | 84 ++++++------------- .../semaphore/semaphore_bench_test.go | 46 +++++----- pkg/foundation/semaphore/semaphore_test.go | 79 +++++++---------- 3 files changed, 76 insertions(+), 133 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go index e1f06af5d..cb5c7c8a3 100644 --- a/pkg/foundation/semaphore/semaphore.go +++ b/pkg/foundation/semaphore/semaphore.go @@ -21,29 +21,19 @@ import ( "github.com/conduitio/conduit/pkg/foundation/cerrors" ) -// NewWeighted creates a new weighted semaphore with the given -// maximum combined weight for concurrent access. -func NewWeighted(n int64) *Weighted { - w := &Weighted{size: n} - return w -} - -// Weighted provides a way to bound concurrent access to a resource. -// The callers can request access with a given weight. -type Weighted struct { - size int64 - cur int64 +// Simple provides a way to bound concurrent access to a resource. 
It only +// allows one caller to gain access at a time. +type Simple struct { + waiters []waiter + front int + batch int64 + acquired bool released int mu sync.Mutex - - waiters []waiter - front int - batch int64 } type waiter struct { index int - n int64 ready chan struct{} // Closed when semaphore acquired. released bool @@ -55,16 +45,12 @@ type Ticket struct { batch int64 } -func (s *Weighted) Enqueue(n int64) Ticket { - if n > s.size { - panic("semaphore: tried to enqueue more than size of semaphore") - } - +func (s *Simple) Enqueue() Ticket { s.mu.Lock() defer s.mu.Unlock() index := len(s.waiters) - w := waiter{index: index, n: n, ready: make(chan struct{})} + w := waiter{index: index, ready: make(chan struct{})} s.waiters = append(s.waiters, w) return Ticket{ @@ -73,12 +59,10 @@ func (s *Weighted) Enqueue(n int64) Ticket { } } -// Acquire acquires the semaphore with a weight of n, blocking until resources -// are available or ctx is done. On success, returns nil. On failure, returns -// ctx.Err() and leaves the semaphore unchanged. -// -// If ctx is already done, Acquire may still succeed without blocking. -func (s *Weighted) Acquire(t Ticket) error { +// Acquire acquires the semaphore, blocking until resources are available. On +// success, returns nil. On failure, returns an error and leaves the semaphore +// unchanged. +func (s *Simple) Acquire(t Ticket) error { s.mu.Lock() if s.batch != t.batch { s.mu.Unlock() @@ -93,13 +77,9 @@ func (s *Weighted) Acquire(t Ticket) error { w.acquired = true // mark that Acquire was already called for this Ticket s.waiters[t.index] = w - if s.front == t.index && s.size-s.cur >= w.n { - s.cur += w.n + if s.front == t.index && !s.acquired { s.front++ - // If there are extra tokens left, notify other waiters. - if s.size > s.cur { - s.notifyWaiters() - } + s.acquired = true s.mu.Unlock() return nil } @@ -109,8 +89,10 @@ func (s *Weighted) Acquire(t Ticket) error { return nil } -// Release releases the semaphore with a weight of n. -func (s *Weighted) Release(t Ticket) error { +// Release releases the semaphore and notifies the next in line if any. +// If the ticket is not holding the lock on the semaphore the function returns +// an error. +func (s *Simple) Release(t Ticket) error { s.mu.Lock() defer s.mu.Unlock() @@ -125,43 +107,29 @@ func (s *Weighted) Release(t Ticket) error { return cerrors.New("semaphore: ticket already released") } - s.cur -= w.n w.released = true s.waiters[t.index] = w + s.acquired = false s.released++ - s.notifyWaiters() + s.notifyWaiter() if s.released == len(s.waiters) { s.increaseBatch() } return nil } -func (s *Weighted) notifyWaiters() { - for len(s.waiters) > s.front { +func (s *Simple) notifyWaiter() { + if len(s.waiters) > s.front { w := s.waiters[s.front] - if s.size-s.cur < w.n { - // Not enough tokens for the next waiter. We could keep going (to try to - // find a waiter with a smaller request), but under load that could cause - // starvation for large requests; instead, we leave all remaining waiters - // blocked. - // - // Consider a semaphore used as a read-write lock, with N tokens, N - // readers, and one writer. Each reader can Acquire(1) to obtain a read - // lock. The writer can Acquire(N) to obtain a write lock, excluding all - // of the readers. If we allow the readers to jump ahead in the queue, - // the writer will starve — there is always one token available for every - // reader. 
- break - } - - s.cur += w.n + s.acquired = true s.front++ close(w.ready) } } -func (s *Weighted) increaseBatch() { +func (s *Simple) increaseBatch() { s.waiters = s.waiters[:0] s.batch += 1 s.front = 0 + s.released = 0 } diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go index 73dcfb686..8a72c4826 100644 --- a/pkg/foundation/semaphore/semaphore_bench_test.go +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -22,22 +22,14 @@ import ( "github.com/conduitio/conduit/pkg/foundation/semaphore" ) -// weighted is an interface matching a subset of *Weighted. It allows -// alternate implementations for testing and benchmarking. -type weighted interface { - Enqueue(int64) semaphore.Ticket - Acquire(semaphore.Ticket) error - Release(semaphore.Ticket) error -} - // acquireN calls Acquire(size) on sem N times and then calls Release(size) N times. -func acquireN(b *testing.B, sem weighted, size int64, N int) { +func acquireN(b *testing.B, sem *semaphore.Simple, N int) { b.ResetTimer() tickets := list.New() for i := 0; i < b.N; i++ { tickets.Init() for j := 0; j < N; j++ { - _ = sem.Enqueue(size) + _ = sem.Enqueue() } } } @@ -46,7 +38,7 @@ func BenchmarkNewSem(b *testing.B) { for _, cap := range []int64{1, 128} { b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) { for i := 0; i < b.N; i++ { - _ = semaphore.NewWeighted(cap) + _ = &semaphore.Simple{} } }) } @@ -54,28 +46,28 @@ func BenchmarkNewSem(b *testing.B) { func BenchmarkAcquireSem(b *testing.B) { for _, c := range []struct { - cap, size int64 - N int + cap int64 + N int }{ - {1, 1, 1}, - {2, 1, 1}, - {16, 1, 1}, - {128, 1, 1}, - {2, 2, 1}, - {16, 2, 8}, - {128, 2, 64}, - {2, 1, 2}, - {16, 8, 2}, - {128, 64, 2}, + {1, 1}, + {2, 1}, + {16, 1}, + {128, 1}, + {2, 1}, + {16, 8}, + {128, 64}, + {2, 2}, + {16, 2}, + {128, 2}, } { for _, w := range []struct { name string - w weighted + w *semaphore.Simple }{ - {"Weighted", semaphore.NewWeighted(c.cap)}, + {"Simple", &semaphore.Simple{}}, } { - b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) { - acquireN(b, w.w, c.size, c.N) + b.Run(fmt.Sprintf("%s-acquire-%d-%d", w.name, c.cap, c.N), func(b *testing.B) { + acquireN(b, w.w, c.N) }) } } diff --git a/pkg/foundation/semaphore/semaphore_test.go b/pkg/foundation/semaphore/semaphore_test.go index 9845f58fc..0e72eb7da 100644 --- a/pkg/foundation/semaphore/semaphore_test.go +++ b/pkg/foundation/semaphore/semaphore_test.go @@ -26,9 +26,9 @@ import ( const maxSleep = 1 * time.Millisecond -func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) { +func HammerSimple(sem *semaphore.Simple, loops int) { for i := 0; i < loops; i++ { - tkn := sem.Enqueue(n) + tkn := sem.Enqueue() err := sem.Acquire(tkn) if err != nil { panic(err) @@ -38,41 +38,40 @@ func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) { } } -func TestWeighted(t *testing.T) { +func TestSimple(t *testing.T) { t.Parallel() n := runtime.GOMAXPROCS(0) - loops := 10000 / n - sem := semaphore.NewWeighted(int64(n)) + loops := 5000 / n + sem := &semaphore.Simple{} var wg sync.WaitGroup wg.Add(n) for i := 0; i < n; i++ { - i := i go func() { defer wg.Done() - HammerWeighted(sem, int64(i), loops) + HammerSimple(sem, loops) }() } wg.Wait() } -func TestWeightedReleaseUnacquired(t *testing.T) { +func TestSimpleReleaseUnacquired(t *testing.T) { t.Parallel() - w := semaphore.NewWeighted(1) - tkn := w.Enqueue(1) + w := &semaphore.Simple{} + tkn := w.Enqueue() err := w.Release(tkn) if 
err == nil { t.Errorf("release of an unacquired ticket did not return an error") } } -func TestWeightedReleaseTwice(t *testing.T) { +func TestSimpleReleaseTwice(t *testing.T) { t.Parallel() - w := semaphore.NewWeighted(1) - tkn := w.Enqueue(1) + w := &semaphore.Simple{} + tkn := w.Enqueue() w.Acquire(tkn) err := w.Release(tkn) if err != nil { @@ -85,11 +84,11 @@ func TestWeightedReleaseTwice(t *testing.T) { } } -func TestWeightedAcquireTwice(t *testing.T) { +func TestSimpleAcquireTwice(t *testing.T) { t.Parallel() - w := semaphore.NewWeighted(1) - tkn := w.Enqueue(1) + w := &semaphore.Simple{} + tkn := w.Enqueue() err := w.Acquire(tkn) if err != nil { t.Errorf("acquire of a ticket errored out: %v", err) @@ -101,51 +100,35 @@ func TestWeightedAcquireTwice(t *testing.T) { } } -func TestWeightedPanicEnqueueTooBig(t *testing.T) { +func TestSimpleAcquire(t *testing.T) { t.Parallel() - defer func() { - if recover() == nil { - t.Fatal("enqueue of size bigger than weighted semaphore did not panic") - } - }() - const n = 5 - sem := semaphore.NewWeighted(n) - sem.Enqueue(n + 1) -} + sem := &semaphore.Simple{} -func TestWeightedAcquire(t *testing.T) { - t.Parallel() - - sem := semaphore.NewWeighted(2) - - tkn1 := sem.Enqueue(1) + tkn1 := sem.Enqueue() sem.Acquire(tkn1) - tkn2 := sem.Enqueue(1) - sem.Acquire(tkn2) - - tkn3done := make(chan struct{}) + tkn2done := make(chan struct{}) go func() { - defer close(tkn3done) - tkn3 := sem.Enqueue(1) - sem.Acquire(tkn3) + defer close(tkn2done) + tkn2 := sem.Enqueue() + sem.Acquire(tkn2) }() select { - case <-tkn3done: - t.Errorf("tkn3done closed prematurely") + case <-tkn2done: + t.Errorf("tkn2done closed prematurely") case <-time.After(time.Millisecond * 10): - // tkn3 Acquire is blocking as expected + // tkn2 Acquire is blocking as expected } sem.Release(tkn1) select { - case <-tkn3done: + case <-tkn2done: // tkn3 successfully acquired the semaphore case <-time.After(time.Millisecond * 10): - t.Errorf("tkn3done didn't get closed") + t.Errorf("tkn2done didn't get closed") } } @@ -155,13 +138,13 @@ func TestLargeAcquireDoesntStarve(t *testing.T) { t.Parallel() n := int64(runtime.GOMAXPROCS(0)) - sem := semaphore.NewWeighted(n) + sem := &semaphore.Simple{} running := true var wg sync.WaitGroup wg.Add(int(n)) for i := n; i > 0; i-- { - tkn := sem.Enqueue(1) + tkn := sem.Enqueue() sem.Acquire(tkn) go func() { defer func() { @@ -171,13 +154,13 @@ func TestLargeAcquireDoesntStarve(t *testing.T) { for running { time.Sleep(1 * time.Millisecond) sem.Release(tkn) - tkn = sem.Enqueue(1) + tkn = sem.Enqueue() sem.Acquire(tkn) } }() } - tkn := sem.Enqueue(n) + tkn := sem.Enqueue() sem.Acquire(tkn) running = false sem.Release(tkn) From 0471fbe2500123ffb06ca388eb9267a47d37fef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 21 Jun 2022 18:45:20 +0200 Subject: [PATCH 07/46] fix linter warnings, better benchmarks --- pkg/foundation/semaphore/semaphore.go | 3 +- .../semaphore/semaphore_bench_test.go | 69 +++++++++---------- pkg/foundation/semaphore/semaphore_test.go | 61 ++++++++++++---- 3 files changed, 81 insertions(+), 52 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go index cb5c7c8a3..86f2279ae 100644 --- a/pkg/foundation/semaphore/semaphore.go +++ b/pkg/foundation/semaphore/semaphore.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package semaphore provides a weighted semaphore implementation. 
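+// Package semaphore provides a simple semaphore implementation.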
package semaphore import ( @@ -129,7 +128,7 @@ func (s *Simple) notifyWaiter() { func (s *Simple) increaseBatch() { s.waiters = s.waiters[:0] - s.batch += 1 + s.batch++ s.front = 0 s.released = 0 } diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go index 8a72c4826..8858428ed 100644 --- a/pkg/foundation/semaphore/semaphore_bench_test.go +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -22,53 +22,46 @@ import ( "github.com/conduitio/conduit/pkg/foundation/semaphore" ) -// acquireN calls Acquire(size) on sem N times and then calls Release(size) N times. -func acquireN(b *testing.B, sem *semaphore.Simple, N int) { - b.ResetTimer() - tickets := list.New() +func BenchmarkNewSem(b *testing.B) { for i := 0; i < b.N; i++ { - tickets.Init() - for j := 0; j < N; j++ { - _ = sem.Enqueue() - } + _ = &semaphore.Simple{} } } -func BenchmarkNewSem(b *testing.B) { - for _, cap := range []int64{1, 128} { - b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) { +func BenchmarkAcquireSem(b *testing.B) { + for _, N := range []int{1, 2, 8, 64, 128} { + b.Run(fmt.Sprintf("acquire-%d", N), func(b *testing.B) { + b.ResetTimer() + sem := &semaphore.Simple{} for i := 0; i < b.N; i++ { - _ = &semaphore.Simple{} + for j := 0; j < N; j++ { + t := sem.Enqueue() + _ = sem.Acquire(t) + _ = sem.Release(t) + } } }) } } -func BenchmarkAcquireSem(b *testing.B) { - for _, c := range []struct { - cap int64 - N int - }{ - {1, 1}, - {2, 1}, - {16, 1}, - {128, 1}, - {2, 1}, - {16, 8}, - {128, 64}, - {2, 2}, - {16, 2}, - {128, 2}, - } { - for _, w := range []struct { - name string - w *semaphore.Simple - }{ - {"Simple", &semaphore.Simple{}}, - } { - b.Run(fmt.Sprintf("%s-acquire-%d-%d", w.name, c.cap, c.N), func(b *testing.B) { - acquireN(b, w.w, c.N) - }) - } +func BenchmarkEnqueueReleaseSem(b *testing.B) { + for _, N := range []int{1, 2, 8, 64, 128} { + b.Run(fmt.Sprintf("enqueue/release-%d", N), func(b *testing.B) { + b.ResetTimer() + sem := &semaphore.Simple{} + tickets := list.New() + for i := 0; i < b.N; i++ { + tickets.Init() + for j := 0; j < N; j++ { + t := sem.Enqueue() + tickets.PushBack(t) + } + ticket := tickets.Front() + for ticket != nil { + _ = sem.Release(ticket.Value.(semaphore.Ticket)) + ticket = ticket.Next() + } + } + }) } } diff --git a/pkg/foundation/semaphore/semaphore_test.go b/pkg/foundation/semaphore/semaphore_test.go index 0e72eb7da..021887c6c 100644 --- a/pkg/foundation/semaphore/semaphore_test.go +++ b/pkg/foundation/semaphore/semaphore_test.go @@ -33,8 +33,12 @@ func HammerSimple(sem *semaphore.Simple, loops int) { if err != nil { panic(err) } + //nolint:gosec // math/rand is good enough for a test time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond) - sem.Release(tkn) + err = sem.Release(tkn) + if err != nil { + panic(err) + } } } @@ -72,8 +76,11 @@ func TestSimpleReleaseTwice(t *testing.T) { w := &semaphore.Simple{} tkn := w.Enqueue() - w.Acquire(tkn) - err := w.Release(tkn) + err := w.Acquire(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + err = w.Release(tkn) if err != nil { t.Errorf("release of an acquired ticket errored out: %v", err) } @@ -106,13 +113,19 @@ func TestSimpleAcquire(t *testing.T) { sem := &semaphore.Simple{} tkn1 := sem.Enqueue() - sem.Acquire(tkn1) + err := sem.Acquire(tkn1) + if err != nil { + t.Errorf("unexpected error: %v", err) + } tkn2done := make(chan struct{}) go func() { defer close(tkn2done) tkn2 := sem.Enqueue() - sem.Acquire(tkn2) + err 
:= sem.Acquire(tkn2) + if err != nil { + t.Errorf("unexpected error: %v", err) + } }() select { @@ -122,7 +135,10 @@ func TestSimpleAcquire(t *testing.T) { // tkn2 Acquire is blocking as expected } - sem.Release(tkn1) + err = sem.Release(tkn1) + if err != nil { + t.Errorf("unexpected error: %v", err) + } select { case <-tkn2done: @@ -145,24 +161,45 @@ func TestLargeAcquireDoesntStarve(t *testing.T) { wg.Add(int(n)) for i := n; i > 0; i-- { tkn := sem.Enqueue() - sem.Acquire(tkn) + err := sem.Acquire(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + go func() { defer func() { - sem.Release(tkn) + err := sem.Release(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } wg.Done() }() for running { time.Sleep(1 * time.Millisecond) - sem.Release(tkn) + err := sem.Release(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } tkn = sem.Enqueue() - sem.Acquire(tkn) + err = sem.Acquire(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } } }() } tkn := sem.Enqueue() - sem.Acquire(tkn) + err := sem.Acquire(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + running = false - sem.Release(tkn) + err = sem.Release(tkn) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + wg.Wait() } From 83f818429749ec18459980335fd55194546e10a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 21 Jun 2022 18:56:03 +0200 Subject: [PATCH 08/46] better docs --- pkg/foundation/semaphore/semaphore.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go index 86f2279ae..9a267eba4 100644 --- a/pkg/foundation/semaphore/semaphore.go +++ b/pkg/foundation/semaphore/semaphore.go @@ -39,11 +39,16 @@ type waiter struct { acquired bool } +// Ticket reserves a place in the queue and can be used to acquire access to a +// resource. type Ticket struct { index int batch int64 } +// Enqueue reserves the next place in the queue and returns a Ticket used to +// acquire access to the resource when it's the callers turn. The Ticket has to +// be supplied to Release before discarding. func (s *Simple) Enqueue() Ticket { s.mu.Lock() defer s.mu.Unlock() @@ -89,8 +94,8 @@ func (s *Simple) Acquire(t Ticket) error { } // Release releases the semaphore and notifies the next in line if any. -// If the ticket is not holding the lock on the semaphore the function returns -// an error. +// If the ticket was already released the function returns an error. After the +// ticket is released it should be discarded. 
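+//
+// A typical Enqueue/Acquire/Release sequence, as exercised by the package
+// tests, looks like this (error handling omitted):
+//
+//	t := sem.Enqueue()
+//	_ = sem.Acquire(t) // blocks until all earlier tickets are released
+//	// ...exclusive access to the shared resource...
+//	_ = sem.Release(t) // lets the next ticket in line acquire the semaphore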
func (s *Simple) Release(t Ticket) error { s.mu.Lock() defer s.mu.Unlock() From 83c97e0120d36225b2df7eaa6540fa70365156ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 21 Jun 2022 18:58:30 +0200 Subject: [PATCH 09/46] go mod tidy --- go.mod | 1 - go.sum | 2 -- 2 files changed, 3 deletions(-) diff --git a/go.mod b/go.mod index 44b4bb415..81cf5535c 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/conduitio/conduit-connector-sdk v0.2.0 github.com/dgraph-io/badger/v3 v3.2103.2 github.com/dop251/goja v0.0.0-20210225094849-f3cfc97811c0 - github.com/gammazero/deque v0.2.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.8 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index 7d881de12..e4a9d71b3 100644 --- a/go.sum +++ b/go.sum @@ -205,8 +205,6 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= -github.com/gammazero/deque v0.2.0 h1:SkieyNB4bg2/uZZLxvya0Pq6diUlwx7m2TeT7GAIWaA= -github.com/gammazero/deque v0.2.0/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= From c0ded08951cb2019997cd7511d54006b5bfa704e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Fri, 10 Jun 2022 19:02:08 +0200 Subject: [PATCH 10/46] rename AckerNode to DestinationAckerNode --- pkg/pipeline/lifecycle.go | 4 +-- pkg/pipeline/stream/destination.go | 2 +- .../stream/{acker.go => destination_acker.go} | 29 ++++++++++--------- ...cker_test.go => destination_acker_test.go} | 4 +-- pkg/pipeline/stream/stream_test.go | 6 ++-- 5 files changed, 23 insertions(+), 22 deletions(-) rename pkg/pipeline/stream/{acker.go => destination_acker.go} (90%) rename pkg/pipeline/stream/{acker_test.go => destination_acker_test.go} (97%) diff --git a/pkg/pipeline/lifecycle.go b/pkg/pipeline/lifecycle.go index e7b8789c0..954166516 100644 --- a/pkg/pipeline/lifecycle.go +++ b/pkg/pipeline/lifecycle.go @@ -290,8 +290,8 @@ func (s *Service) buildMetricsNode( func (s *Service) buildAckerNode( dest connector.Destination, -) *stream.AckerNode { - return &stream.AckerNode{ +) *stream.DestinationAckerNode { + return &stream.DestinationAckerNode{ Name: dest.ID() + "-acker", Destination: dest, } diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go index 467b97f7e..0c75f2179 100644 --- a/pkg/pipeline/stream/destination.go +++ b/pkg/pipeline/stream/destination.go @@ -30,7 +30,7 @@ type DestinationNode struct { Destination connector.Destination ConnectorTimer metrics.Timer // AckerNode is responsible for handling acks - AckerNode *AckerNode + AckerNode *DestinationAckerNode base subNodeBase logger log.CtxLogger diff --git a/pkg/pipeline/stream/acker.go b/pkg/pipeline/stream/destination_acker.go similarity index 90% rename from pkg/pipeline/stream/acker.go rename to pkg/pipeline/stream/destination_acker.go index 3f819605b..f2fa04861 100644 --- a/pkg/pipeline/stream/acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -27,9 +27,9 @@ import ( 
"github.com/conduitio/conduit/pkg/record" ) -// AckerNode is responsible for handling acknowledgments received from the -// destination and forwarding them to the correct message. -type AckerNode struct { +// DestinationAckerNode is responsible for handling acknowledgments received +// from the destination and forwarding them to the correct message. +type DestinationAckerNode struct { Name string Destination connector.Destination @@ -49,8 +49,8 @@ type AckerNode struct { stopOnce sync.Once } -// init initializes AckerNode internal fields. -func (n *AckerNode) init() { +// init initializes DestinationAckerNode internal fields. +func (n *DestinationAckerNode) init() { n.initOnce.Do(func() { n.cache = &positionMessageMap{} n.start = make(chan struct{}) @@ -58,13 +58,13 @@ func (n *AckerNode) init() { }) } -func (n *AckerNode) ID() string { +func (n *DestinationAckerNode) ID() string { return n.Name } // Run continuously fetches acks from the destination and forwards them to the // correct message by calling Ack or Nack on that message. -func (n *AckerNode) Run(ctx context.Context) (err error) { +func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { n.logger.Trace(ctx).Msg("starting acker node") defer n.logger.Trace(ctx).Msg("acker node stopped") @@ -122,7 +122,7 @@ func (n *AckerNode) Run(ctx context.Context) (err error) { // teardown will drop all messages still in the cache and return an error in // case there were still unprocessed messages in the cache. -func (n *AckerNode) teardown() error { +func (n *DestinationAckerNode) teardown() error { var dropped int n.cache.Range(func(pos record.Position, msg *Message) bool { msg.Drop() @@ -138,7 +138,7 @@ func (n *AckerNode) teardown() error { // handleAck either acks or nacks the message, depending on the supplied error. // If the nacking or acking fails, the message is dropped and the error is // returned. -func (n *AckerNode) handleAck(msg *Message, err error) error { +func (n *DestinationAckerNode) handleAck(msg *Message, err error) error { switch { case err != nil: n.logger.Trace(msg.Ctx).Err(err).Msg("nacking message") @@ -160,7 +160,7 @@ func (n *AckerNode) handleAck(msg *Message, err error) error { // ExpectAck makes the handler aware of the message and signals to it that an // ack for this message might be received at some point. -func (n *AckerNode) ExpectAck(msg *Message) error { +func (n *DestinationAckerNode) ExpectAck(msg *Message) error { // happens only once to signal Run that the destination is ready to be used. n.startOnce.Do(func() { n.init() @@ -185,7 +185,7 @@ func (n *AckerNode) ExpectAck(msg *Message) error { // ForgetAndDrop signals the handler that an ack for this message won't be // received, and it should remove it from its cache. In case an ack for this // message wasn't yet received it drops the message, otherwise it does nothing. -func (n *AckerNode) ForgetAndDrop(msg *Message) { +func (n *DestinationAckerNode) ForgetAndDrop(msg *Message) { _, ok := n.cache.LoadAndDelete(msg.Record.Position) if !ok { // message wasn't found in the cache, looks like the message was already @@ -197,8 +197,9 @@ func (n *AckerNode) ForgetAndDrop(msg *Message) { // Wait can be used to wait for the count of outstanding acks to drop to 0 or // the context gets canceled. Wait is expected to be the last function called on -// AckerNode, after Wait returns AckerNode will soon stop running. 
-func (n *AckerNode) Wait(ctx context.Context) { +// DestinationAckerNode, after Wait returns DestinationAckerNode will soon stop +// running. +func (n *DestinationAckerNode) Wait(ctx context.Context) { // happens only once to signal that the destination is stopping n.stopOnce.Do(func() { n.init() @@ -227,7 +228,7 @@ func (n *AckerNode) Wait(ctx context.Context) { } // SetLogger sets the logger. -func (n *AckerNode) SetLogger(logger log.CtxLogger) { +func (n *DestinationAckerNode) SetLogger(logger log.CtxLogger) { n.logger = logger } diff --git a/pkg/pipeline/stream/acker_test.go b/pkg/pipeline/stream/destination_acker_test.go similarity index 97% rename from pkg/pipeline/stream/acker_test.go rename to pkg/pipeline/stream/destination_acker_test.go index cc33e04c9..0a2a447af 100644 --- a/pkg/pipeline/stream/acker_test.go +++ b/pkg/pipeline/stream/destination_acker_test.go @@ -32,7 +32,7 @@ func TestAckerNode_Run_StopAfterWait(t *testing.T) { ctrl := gomock.NewController(t) dest := mock.NewDestination(ctrl) - node := &AckerNode{ + node := &DestinationAckerNode{ Name: "acker-node", Destination: dest, } @@ -69,7 +69,7 @@ func TestAckerNode_Run_StopAfterExpectAck(t *testing.T) { ctrl := gomock.NewController(t) dest := mock.NewDestination(ctrl) - node := &AckerNode{ + node := &DestinationAckerNode{ Name: "acker-node", Destination: dest, } diff --git a/pkg/pipeline/stream/stream_test.go b/pkg/pipeline/stream/stream_test.go index 8e0a1d909..57b1b2cb7 100644 --- a/pkg/pipeline/stream/stream_test.go +++ b/pkg/pipeline/stream/stream_test.go @@ -52,7 +52,7 @@ func Example_simpleStream() { Destination: printerDestination(ctrl, logger, "printer"), ConnectorTimer: noop.Timer{}, } - node3 := &stream.AckerNode{ + node3 := &stream.DestinationAckerNode{ Name: "printer-acker", Destination: node2.Destination, } @@ -144,12 +144,12 @@ func Example_complexStream() { Destination: printerDestination(ctrl, logger, "printer2"), ConnectorTimer: noop.Timer{}, } - node8 := &stream.AckerNode{ + node8 := &stream.DestinationAckerNode{ Name: "printer1-acker", Destination: node6.Destination, } node6.AckerNode = node8 - node9 := &stream.AckerNode{ + node9 := &stream.DestinationAckerNode{ Name: "printer2-acker", Destination: node7.Destination, } From c66c2739a1803373b8a179a88d84c2b027ac55e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 22 Jun 2022 19:29:25 +0200 Subject: [PATCH 11/46] remove message status change middleware to ensure all message handlers are called --- pkg/pipeline/lifecycle.go | 8 +-- pkg/pipeline/stream/fanout.go | 27 +++------ pkg/pipeline/stream/message.go | 71 +++++++---------------- pkg/pipeline/stream/message_test.go | 83 +++++++++++++-------------- pkg/pipeline/stream/metrics.go | 4 +- pkg/pipeline/stream/processor_test.go | 8 +-- pkg/pipeline/stream/source.go | 12 ++-- 7 files changed, 83 insertions(+), 130 deletions(-) diff --git a/pkg/pipeline/lifecycle.go b/pkg/pipeline/lifecycle.go index 954166516..9b867ab51 100644 --- a/pkg/pipeline/lifecycle.go +++ b/pkg/pipeline/lifecycle.go @@ -288,7 +288,7 @@ func (s *Service) buildMetricsNode( } } -func (s *Service) buildAckerNode( +func (s *Service) buildDestinationAckerNode( dest connector.Destination, ) *stream.DestinationAckerNode { return &stream.DestinationAckerNode{ @@ -316,7 +316,7 @@ func (s *Service) buildDestinationNodes( continue // skip any connector that's not a destination } - ackerNode := s.buildAckerNode(instance.(connector.Destination)) + ackerNode := s.buildDestinationAckerNode(instance.(connector.Destination)) 
destinationNode := stream.DestinationNode{ Name: instance.ID(), Destination: instance.(connector.Destination), @@ -358,7 +358,7 @@ func (s *Service) runPipeline(ctx context.Context, pl *Instance) error { // If any of the nodes stops, the nodesTomb will be put into a dying state // and ctx will be cancelled. // This way, the other nodes will be notified that they need to stop too. - //nolint: staticcheck // nil used to use the default (parent provided via WithContext) + // nolint: staticcheck // nil used to use the default (parent provided via WithContext) ctx := nodesTomb.Context(nil) s.logger.Trace(ctx).Str(log.NodeIDField, node.ID()).Msg("running node") defer func() { @@ -406,7 +406,7 @@ func (s *Service) runPipeline(ctx context.Context, pl *Instance) error { // before declaring the pipeline as stopped. pl.t = &tomb.Tomb{} pl.t.Go(func() error { - //nolint: staticcheck // nil used to use the default (parent provided via WithContext) + // nolint: staticcheck // nil used to use the default (parent provided via WithContext) ctx := pl.t.Context(nil) err := nodesTomb.Wait() diff --git a/pkg/pipeline/stream/fanout.go b/pkg/pipeline/stream/fanout.go index 6e71aed26..c6c2aa9ec 100644 --- a/pkg/pipeline/stream/fanout.go +++ b/pkg/pipeline/stream/fanout.go @@ -146,38 +146,27 @@ func (n *FanoutNode) Run(ctx context.Context) error { // wrapAckHandler modifies the ack handler, so it's called with the original // message received by FanoutNode instead of the new message created by // FanoutNode. -func (n *FanoutNode) wrapAckHandler(origMsg *Message, f AckHandler) AckMiddleware { - return func(newMsg *Message, next AckHandler) error { - err := f(origMsg) - if err != nil { - return err - } - // next handler is called again with new message - return next(newMsg) +func (n *FanoutNode) wrapAckHandler(origMsg *Message, f AckHandler) AckHandler { + return func(_ *Message) error { + return f(origMsg) } } // wrapNackHandler modifies the nack handler, so it's called with the original // message received by FanoutNode instead of the new message created by // FanoutNode. -func (n *FanoutNode) wrapNackHandler(origMsg *Message, f NackHandler) NackMiddleware { - return func(newMsg *Message, reason error, next NackHandler) error { - err := f(origMsg, reason) - if err != nil { - return err - } - // next handler is called again with new message - return next(newMsg, err) +func (n *FanoutNode) wrapNackHandler(origMsg *Message, f NackHandler) NackHandler { + return func(_ *Message, reason error) error { + return f(origMsg, reason) } } // wrapDropHandler modifies the drop handler, so it's called with the original // message received by FanoutNode instead of the new message created by // FanoutNode. -func (n *FanoutNode) wrapDropHandler(origMsg *Message, f DropHandler) DropMiddleware { - return func(newMsg *Message, reason error, next DropHandler) { +func (n *FanoutNode) wrapDropHandler(origMsg *Message, f DropHandler) DropHandler { + return func(_ *Message, reason error) { f(origMsg, reason) - next(newMsg, reason) } } diff --git a/pkg/pipeline/stream/message.go b/pkg/pipeline/stream/message.go index f8d833297..5a2c761e3 100644 --- a/pkg/pipeline/stream/message.go +++ b/pkg/pipeline/stream/message.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/conduitio/conduit/pkg/foundation/cerrors" + "github.com/conduitio/conduit/pkg/foundation/multierror" "github.com/conduitio/conduit/pkg/record" ) @@ -84,45 +85,21 @@ type ( // a nack or drop. 
StatusChangeHandler func(*Message, StatusChange) error - // StatusChangeMiddleware can be registered on a message and will be executed in - // case of a status change (see StatusChangeHandler). Middlewares are called in - // the reverse order of how they were registered. - // The middleware has two options when processing a message status change: - // - If it successfully processed the status change it should call the next - // handler and return its error. The handler may inspect the error and act - // accordingly, but it must return that error (or another error that - // contains it). It must not return an error if the next handler was called - // and it returned nil. - // - If it failed to process the status change successfully it must not call - // the next handler but instead return an error right away. - // Applying these rules means each middleware can be sure that all middlewares - // before it processed the status change successfully. - StatusChangeMiddleware func(*Message, StatusChange, StatusChangeHandler) error - // AckHandler is a variation of the StatusChangeHandler that is only called // when a message is acked. For more info see StatusChangeHandler. AckHandler func(*Message) error - // AckMiddleware is a variation of the StatusChangeMiddleware that is only - // called when a message is acked. For more info see StatusChangeMiddleware. - AckMiddleware func(*Message, AckHandler) error // NackHandler is a variation of the StatusChangeHandler that is only called // when a message is nacked. For more info see StatusChangeHandler. NackHandler func(*Message, error) error - // NackMiddleware is a variation of the StatusChangeMiddleware that is only - // called when a message is nacked. For more info see StatusChangeMiddleware. - NackMiddleware func(*Message, error, NackHandler) error // DropHandler is a variation of the StatusChangeHandler that is only called // when a message is dropped. For more info see StatusChangeHandler. DropHandler func(*Message, error) - // DropMiddleware is a variation of the StatusChangeMiddleware that is only - // called when a message is dropped. For more info see StatusChangeMiddleware. - DropMiddleware func(*Message, error, DropHandler) ) -// StatusChange is passed to StatusChangeMiddleware and StatusChangeHandler when -// the status of a message changes. +// StatusChange is passed to StatusChangeHandler when the status of a message +// changes. type StatusChange struct { Old MessageStatus New MessageStatus @@ -150,9 +127,9 @@ func (m *Message) ID() string { // RegisterStatusHandler is used to register a function that will be called on // any status change of the message. This function can only be called if the -// message status is open, otherwise it panics. Middlewares are called in the +// message status is open, otherwise it panics. Handlers are called in the // reverse order of how they were registered. -func (m *Message) RegisterStatusHandler(mw StatusChangeMiddleware) { +func (m *Message) RegisterStatusHandler(mw StatusChangeHandler) { m.init() m.handlerGuard.Lock() defer m.handlerGuard.Unlock() @@ -163,35 +140,34 @@ func (m *Message) RegisterStatusHandler(mw StatusChangeMiddleware) { next := m.handler m.handler = func(msg *Message, change StatusChange) error { - return mw(msg, change, next) + // all handlers are called and errors collected + err1 := mw(msg, change) + err2 := next(msg, change) + return multierror.Append(err1, err2) } } // RegisterAckHandler is used to register a function that will be called when // the message is acked. 
This function can only be called if the message status // is open, otherwise it panics. -func (m *Message) RegisterAckHandler(mw AckMiddleware) { - m.RegisterStatusHandler(func(msg *Message, change StatusChange, next StatusChangeHandler) error { +func (m *Message) RegisterAckHandler(mw AckHandler) { + m.RegisterStatusHandler(func(msg *Message, change StatusChange) error { if change.New != MessageStatusAcked { - return next(msg, change) + return nil // skip } - return mw(msg, func(msg *Message) error { - return next(msg, change) - }) + return mw(msg) }) } // RegisterNackHandler is used to register a function that will be called when // the message is nacked. This function can only be called if the message status // is open, otherwise it panics. -func (m *Message) RegisterNackHandler(mw NackMiddleware) { - m.RegisterStatusHandler(func(msg *Message, change StatusChange, next StatusChangeHandler) error { +func (m *Message) RegisterNackHandler(mw NackHandler) { + m.RegisterStatusHandler(func(msg *Message, change StatusChange) error { if change.New != MessageStatusNacked { - return next(msg, change) + return nil // skip } - return mw(msg, change.Reason, func(msg *Message, reason error) error { - return next(msg, change) - }) + return mw(msg, change.Reason) }) m.hasNackHandler = true } @@ -199,17 +175,12 @@ func (m *Message) RegisterNackHandler(mw NackMiddleware) { // RegisterDropHandler is used to register a function that will be called when // the message is dropped. This function can only be called if the message // status is open, otherwise it panics. -func (m *Message) RegisterDropHandler(mw DropMiddleware) { - m.RegisterStatusHandler(func(msg *Message, change StatusChange, next StatusChangeHandler) error { +func (m *Message) RegisterDropHandler(mw DropHandler) { + m.RegisterStatusHandler(func(msg *Message, change StatusChange) error { if change.New != MessageStatusDropped { - return next(msg, change) + return nil } - mw(msg, change.Reason, func(msg *Message, reason error) { - err := next(msg, change) - if err != nil { - panic(cerrors.Errorf("BUG: drop handlers should never return an error (message %s): %w", msg.ID(), err)) - } - }) + mw(msg, change.Reason) return nil }) } diff --git a/pkg/pipeline/stream/message_test.go b/pkg/pipeline/stream/message_test.go index 322009c9e..12c35b2f5 100644 --- a/pkg/pipeline/stream/message_test.go +++ b/pkg/pipeline/stream/message_test.go @@ -47,7 +47,7 @@ func TestMessage_Ack_WithHandler(t *testing.T) { ackedMessageHandlerCallCount int ) - msg.RegisterAckHandler(func(*Message, AckHandler) error { + msg.RegisterAckHandler(func(*Message) error { ackedMessageHandlerCallCount++ return nil }) @@ -90,35 +90,34 @@ func TestMessage_Ack_WithFailingHandler(t *testing.T) { ) { - // first handler should never be called - msg.RegisterAckHandler(func(*Message, AckHandler) error { - t.Fatalf("did not expect first handler to be called") + // first handler should still be called + msg.RegisterAckHandler(func(*Message) error { + ackedMessageHandlerCallCount++ return nil }) // second handler fails - msg.RegisterAckHandler(func(*Message, AckHandler) error { + msg.RegisterAckHandler(func(*Message) error { return wantErr }) // third handler should work as expected - msg.RegisterAckHandler(func(msg *Message, next AckHandler) error { + msg.RegisterAckHandler(func(msg *Message) error { ackedMessageHandlerCallCount++ - return next(msg) + return nil }) // fourth handler should be called twice, once for ack, once for drop - msg.RegisterStatusHandler(func(msg *Message, change 
StatusChange, next StatusChangeHandler) error { + msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { statusMessageHandlerCallCount++ - return next(msg, change) + return nil }) // drop handler should be called after the ack fails - msg.RegisterDropHandler(func(msg *Message, reason error, next DropHandler) { - if ackedMessageHandlerCallCount != 1 { - t.Fatal("expected acked message handler to already be called") + msg.RegisterDropHandler(func(msg *Message, reason error) { + if ackedMessageHandlerCallCount != 2 { + t.Fatal("expected acked message handlers to already be called") } droppedMessageHandlerCallCount++ - next(msg, reason) }) // nack handler should not be called - msg.RegisterNackHandler(func(*Message, error, NackHandler) error { + msg.RegisterNackHandler(func(*Message, error) error { t.Fatalf("did not expect nack handler to be called") return nil }) @@ -131,8 +130,8 @@ func TestMessage_Ack_WithFailingHandler(t *testing.T) { t.Fatalf("ack expected error %v, got: %v", wantErr, err) } assertMessageIsDropped(t, &msg) - if ackedMessageHandlerCallCount != 1 { - t.Fatalf("expected acked message handler to be called once, got %d calls", ackedMessageHandlerCallCount) + if ackedMessageHandlerCallCount != 2 { + t.Fatalf("expected acked message handler to be called twice, got %d calls", ackedMessageHandlerCallCount) } if droppedMessageHandlerCallCount != 1 { t.Fatalf("expected dropped message handler to be called once, got %d calls", droppedMessageHandlerCallCount) @@ -186,12 +185,12 @@ func TestMessage_Nack_WithHandler(t *testing.T) { nackedMessageHandlerCallCount int ) - msg.RegisterNackHandler(func(msg *Message, err error, next NackHandler) error { + msg.RegisterNackHandler(func(msg *Message, err error) error { nackedMessageHandlerCallCount++ if err != wantErr { t.Fatalf("nacked message handler, expected err %v, got %v", wantErr, err) } - return next(msg, err) + return nil }) err := msg.Nack(wantErr) @@ -225,35 +224,34 @@ func TestMessage_Nack_WithFailingHandler(t *testing.T) { ) { - // first handler should never be called - msg.RegisterNackHandler(func(*Message, error, NackHandler) error { - t.Fatalf("did not expect first handler to be called") + // first handler should still be called + msg.RegisterNackHandler(func(*Message, error) error { + nackedMessageHandlerCallCount++ return nil }) // second handler fails - msg.RegisterNackHandler(func(*Message, error, NackHandler) error { + msg.RegisterNackHandler(func(*Message, error) error { return wantErr }) // third handler should work as expected - msg.RegisterNackHandler(func(msg *Message, reason error, next NackHandler) error { + msg.RegisterNackHandler(func(msg *Message, reason error) error { nackedMessageHandlerCallCount++ - return next(msg, reason) + return nil }) // fourth handler should be called twice, once for ack, once for drop - msg.RegisterStatusHandler(func(msg *Message, change StatusChange, next StatusChangeHandler) error { + msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { statusMessageHandlerCallCount++ - return next(msg, change) + return nil }) // drop handler should be called after the nack fails - msg.RegisterDropHandler(func(msg *Message, reason error, next DropHandler) { - if nackedMessageHandlerCallCount != 1 { - t.Fatal("expected nacked message handler to already be called") + msg.RegisterDropHandler(func(msg *Message, reason error) { + if nackedMessageHandlerCallCount != 2 { + t.Fatal("expected nacked message handlers to already be called") } droppedMessageHandlerCallCount++ 
- next(msg, reason) }) // ack handler should not be called - msg.RegisterAckHandler(func(*Message, AckHandler) error { + msg.RegisterAckHandler(func(*Message) error { t.Fatalf("did not expect ack handler to be called") return nil }) @@ -266,8 +264,8 @@ func TestMessage_Nack_WithFailingHandler(t *testing.T) { t.Fatalf("nack expected error %v, got: %v", wantErr, err) } assertMessageIsDropped(t, &msg) - if nackedMessageHandlerCallCount != 1 { - t.Fatalf("expected nacked message handler to be called once, got %d calls", nackedMessageHandlerCallCount) + if nackedMessageHandlerCallCount != 2 { + t.Fatalf("expected nacked message handler to be called twice, got %d calls", nackedMessageHandlerCallCount) } if droppedMessageHandlerCallCount != 1 { t.Fatalf("expected dropped message handler to be called once, got %d calls", droppedMessageHandlerCallCount) @@ -309,14 +307,13 @@ func TestMessage_Drop_WithHandler(t *testing.T) { ) { - msg.RegisterDropHandler(func(msg *Message, reason error, next DropHandler) { + msg.RegisterDropHandler(func(msg *Message, reason error) { droppedMessageHandlerCallCount++ - next(msg, reason) }) // second handler should be called once for drop - msg.RegisterStatusHandler(func(msg *Message, change StatusChange, next StatusChangeHandler) error { + msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { statusMessageHandlerCallCount++ - return next(msg, change) + return nil }) } @@ -337,7 +334,7 @@ func TestMessage_Drop_WithFailingHandler(t *testing.T) { var msg Message // handler return error for drop - msg.RegisterStatusHandler(func(msg *Message, change StatusChange, next StatusChangeHandler) error { + msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { return cerrors.New("oops") }) @@ -391,7 +388,7 @@ func TestMessage_StatusChangeTwice(t *testing.T) { t.Run("nacked message", func(t *testing.T) { var msg Message // need to register a nack handler for message to be nacked - msg.RegisterNackHandler(func(*Message, error, NackHandler) error { return nil }) + msg.RegisterNackHandler(func(*Message, error) error { return nil }) err := msg.Nack(nil) if err != nil { t.Fatalf("ack did not expect error, got %v", err) @@ -424,7 +421,7 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { t.Fatalf("expected msg.RegisterAckHandler to panic") } }() - msg.RegisterAckHandler(func(*Message, AckHandler) error { return nil }) + msg.RegisterAckHandler(func(*Message) error { return nil }) } assertRegisterNackHandlerPanics := func(msg *Message) { defer func() { @@ -432,7 +429,7 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { t.Fatalf("expected msg.RegisterNackHandler to panic") } }() - msg.RegisterNackHandler(func(*Message, error, NackHandler) error { return nil }) + msg.RegisterNackHandler(func(*Message, error) error { return nil }) } assertRegisterDropHandlerPanics := func(msg *Message) { defer func() { @@ -440,7 +437,7 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { t.Fatalf("expected msg.RegisterDropHandler to panic") } }() - msg.RegisterDropHandler(func(*Message, error, DropHandler) {}) + msg.RegisterDropHandler(func(*Message, error) {}) } // registering a handler after the message is acked should panic @@ -459,7 +456,7 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { t.Run("nacked message", func(t *testing.T) { var msg Message // need to register a nack handler for message to be nacked - msg.RegisterNackHandler(func(*Message, error, NackHandler) error { return nil }) + msg.RegisterNackHandler(func(*Message, error) error { 
return nil }) err := msg.Nack(nil) if err != nil { t.Fatalf("ack did not expect error, got %v", err) diff --git a/pkg/pipeline/stream/metrics.go b/pkg/pipeline/stream/metrics.go index 406c5801d..b6dd3bb13 100644 --- a/pkg/pipeline/stream/metrics.go +++ b/pkg/pipeline/stream/metrics.go @@ -46,7 +46,7 @@ func (n *MetricsNode) Run(ctx context.Context) error { return err } - msg.RegisterAckHandler(func(msg *Message, next AckHandler) error { + msg.RegisterAckHandler(func(msg *Message) error { // TODO for now we call method Bytes() on key and payload to get the // bytes representation. In case of a structured payload or key it // is marshaled into JSON, which might not be the correct way to @@ -60,7 +60,7 @@ func (n *MetricsNode) Run(ctx context.Context) error { bytes += len(msg.Record.Payload.Bytes()) } n.BytesHistogram.Observe(float64(bytes)) - return next(msg) + return nil }) err = n.base.Send(ctx, n.logger, msg) diff --git a/pkg/pipeline/stream/processor_test.go b/pkg/pipeline/stream/processor_test.go index 1bb324425..43183f554 100644 --- a/pkg/pipeline/stream/processor_test.go +++ b/pkg/pipeline/stream/processor_test.go @@ -143,9 +143,9 @@ func TestProcessorNode_ErrorWithNackHandler(t *testing.T) { out := n.Pub() msg := &Message{Ctx: ctx} - msg.RegisterNackHandler(func(msg *Message, err error, next NackHandler) error { + msg.RegisterNackHandler(func(msg *Message, err error) error { assert.True(t, cerrors.Is(err, wantErr), "expected underlying error to be the transform error") - return next(msg, err) // the error should be regarded as handled + return nil // the error should be regarded as handled }) go func() { // publisher @@ -186,11 +186,11 @@ func TestProcessorNode_Skip(t *testing.T) { // register a dummy AckHandler and NackHandler for tests. counter := 0 - msg.RegisterAckHandler(func(msg *Message, next AckHandler) error { + msg.RegisterAckHandler(func(msg *Message) error { counter++ return nil }) - msg.RegisterNackHandler(func(msg *Message, err error, next NackHandler) error { + msg.RegisterNackHandler(func(msg *Message, err error) error { // Our NackHandler shouldn't ever be hit if we're correctly skipping // so fail the test if we get here at all. 
t.Fail() diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go index 86bd8a788..16d8339fc 100644 --- a/pkg/pipeline/stream/source.go +++ b/pkg/pipeline/stream/source.go @@ -104,24 +104,20 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { // register another open message wgOpenMessages.Add(1) msg.RegisterStatusHandler( - func(msg *Message, change StatusChange, next StatusChangeHandler) error { + func(msg *Message, change StatusChange) error { // this is the last handler to be executed, once this handler is // reached we know either the message was successfully acked, nacked // or dropped defer n.PipelineTimer.Update(time.Since(msg.Record.ReadAt)) defer wgOpenMessages.Done() - return next(msg, change) + return nil }, ) msg.RegisterAckHandler( - func(msg *Message, next AckHandler) error { + func(msg *Message) error { n.logger.Trace(msg.Ctx).Msg("forwarding ack to source connector") - err := n.Source.Ack(msg.Ctx, msg.Record.Position) - if err != nil { - return err - } - return next(msg) + return n.Source.Ack(msg.Ctx, msg.Record.Position) }, ) From 68cd7c7e18066500063db2e1fa247057e0077c3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 22 Jun 2022 20:26:08 +0200 Subject: [PATCH 12/46] implement SourceAckerNode --- pkg/pipeline/lifecycle.go | 15 +++- pkg/pipeline/lifecycle_test.go | 1 + pkg/pipeline/stream/source.go | 7 -- pkg/pipeline/stream/source_acker.go | 129 ++++++++++++++++++++++++++++ pkg/pipeline/stream/stream_test.go | 80 ++++++++++------- 5 files changed, 190 insertions(+), 42 deletions(-) create mode 100644 pkg/pipeline/stream/source_acker.go diff --git a/pkg/pipeline/lifecycle.go b/pkg/pipeline/lifecycle.go index 9b867ab51..3c2e6f17f 100644 --- a/pkg/pipeline/lifecycle.go +++ b/pkg/pipeline/lifecycle.go @@ -233,6 +233,15 @@ func (s *Service) buildProcessorNodes( return nodes, nil } +func (s *Service) buildSourceAckerNode( + src connector.Source, +) *stream.SourceAckerNode { + return &stream.SourceAckerNode{ + Name: src.ID() + "-acker", + Source: src, + } +} + func (s *Service) buildSourceNodes( ctx context.Context, connFetcher ConnectorFetcher, @@ -259,15 +268,17 @@ func (s *Service) buildSourceNodes( pl.Config.Name, ), } + ackerNode := s.buildSourceAckerNode(instance.(connector.Source)) + ackerNode.Sub(sourceNode.Pub()) metricsNode := s.buildMetricsNode(pl, instance) - metricsNode.Sub(sourceNode.Pub()) + metricsNode.Sub(ackerNode.Pub()) procNodes, err := s.buildProcessorNodes(ctx, procFetcher, pl, instance.Config().ProcessorIDs, metricsNode, next) if err != nil { return nil, cerrors.Errorf("could not build processor nodes for connector %s: %w", instance.ID(), err) } - nodes = append(nodes, &sourceNode, metricsNode) + nodes = append(nodes, &sourceNode, ackerNode, metricsNode) nodes = append(nodes, procNodes...) 
} diff --git a/pkg/pipeline/lifecycle_test.go b/pkg/pipeline/lifecycle_test.go index c15b783fa..2424c3aa4 100644 --- a/pkg/pipeline/lifecycle_test.go +++ b/pkg/pipeline/lifecycle_test.go @@ -111,6 +111,7 @@ func TestServiceLifecycle_PipelineError(t *testing.T) { // wait for pipeline to finish err = pl.Wait() assert.Error(t, err) + t.Log(err) assert.Equal(t, StatusDegraded, pl.Status) // pipeline errors contain only string messages, so we can only compare the errors by the messages diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go index 16d8339fc..9c827c5fc 100644 --- a/pkg/pipeline/stream/source.go +++ b/pkg/pipeline/stream/source.go @@ -114,13 +114,6 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { }, ) - msg.RegisterAckHandler( - func(msg *Message) error { - n.logger.Trace(msg.Ctx).Msg("forwarding ack to source connector") - return n.Source.Ack(msg.Ctx, msg.Record.Position) - }, - ) - err = n.base.Send(ctx, n.logger, msg) if err != nil { msg.Drop() diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go new file mode 100644 index 000000000..9dc6c759f --- /dev/null +++ b/pkg/pipeline/stream/source_acker.go @@ -0,0 +1,129 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stream + +import ( + "context" + + "github.com/conduitio/conduit/pkg/connector" + "github.com/conduitio/conduit/pkg/foundation/cerrors" + "github.com/conduitio/conduit/pkg/foundation/log" + "github.com/conduitio/conduit/pkg/foundation/semaphore" +) + +// SourceAckerNode is responsible for handling acknowledgments for messages of +// a specific source and forwarding them to the source in the correct order. 
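+//
+// Ordering is enforced with a semaphore: Run enqueues a semaphore ticket for
+// every message before forwarding it, and the ack/nack handlers registered on
+// the message first acquire that ticket and release it once the ack has been
+// handled. Downstream nodes may therefore ack or nack messages in any order,
+// yet the acks are forwarded to the source strictly in the order in which the
+// messages were dispatched.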
+type SourceAckerNode struct { + Name string + Source connector.Source + + base pubSubNodeBase + logger log.CtxLogger + + // sem ensures acks are sent to the source in the correct order and only one + // at a time + sem semaphore.Simple +} + +func (n *SourceAckerNode) ID() string { + return n.Name +} + +func (n *SourceAckerNode) Run(ctx context.Context) error { + trigger, cleanup, err := n.base.Trigger(ctx, n.logger) + if err != nil { + return err + } + + defer cleanup() + for { + msg, err := trigger() + if err != nil || msg == nil { + return err + } + + // enqueue message in semaphore + ticket := n.sem.Enqueue() + n.registerAckHandler(msg, ticket) + n.registerNackHandler(msg, ticket) + + err = n.base.Send(ctx, n.logger, msg) + if err != nil { + msg.Drop() + return err + } + } +} + +func (n *SourceAckerNode) registerAckHandler(msg *Message, ticket semaphore.Ticket) { + msg.RegisterAckHandler( + func(msg *Message) (err error) { + defer func() { + tmpErr := n.sem.Release(ticket) + if err != nil { + // we are already returning an error, log this one instead + n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for ack") + return + } + err = tmpErr + }() + n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for ack") + err = n.sem.Acquire(ticket) + if err != nil { + return cerrors.Errorf("could not acquire semaphore for ack: %w", err) + } + n.logger.Trace(msg.Ctx).Msg("forwarding ack to source connector") + return n.Source.Ack(msg.Ctx, msg.Record.Position) + }, + ) +} + +func (n *SourceAckerNode) registerNackHandler(msg *Message, ticket semaphore.Ticket) { + msg.RegisterNackHandler( + func(msg *Message, reason error) (err error) { + defer func() { + tmpErr := n.sem.Release(ticket) + if err != nil { + // we are already returning an error, log this one instead + n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for nack") + return + } + err = tmpErr + }() + n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for nack") + err = n.sem.Acquire(ticket) + if err != nil { + return cerrors.Errorf("could not acquire semaphore for nack: %w", err) + } + n.logger.Trace(msg.Ctx).Msg("forwarding nack to DLQ handler") + // TODO implement DLQ and call it here, right now any nacked message + // will just stop the pipeline because we don't support DLQs + // https://github.com/ConduitIO/conduit/issues/306 + return cerrors.New("no DLQ handler configured") + }, + ) +} + +func (n *SourceAckerNode) Sub(in <-chan *Message) { + n.base.Sub(in) +} + +func (n *SourceAckerNode) Pub() <-chan *Message { + return n.base.Pub() +} + +func (n *SourceAckerNode) SetLogger(logger log.CtxLogger) { + n.logger = logger +} diff --git a/pkg/pipeline/stream/stream_test.go b/pkg/pipeline/stream/stream_test.go index 57b1b2cb7..18366f8b4 100644 --- a/pkg/pipeline/stream/stream_test.go +++ b/pkg/pipeline/stream/stream_test.go @@ -47,27 +47,33 @@ func Example_simpleStream() { Source: generatorSource(ctrl, logger, "generator", 10, time.Millisecond*10), PipelineTimer: noop.Timer{}, } - node2 := &stream.DestinationNode{ + node2 := &stream.SourceAckerNode{ + Name: "generator-acker", + Source: node1.Source, + } + node3 := &stream.DestinationNode{ Name: "printer", Destination: printerDestination(ctrl, logger, "printer"), ConnectorTimer: noop.Timer{}, } - node3 := &stream.DestinationAckerNode{ + node4 := &stream.DestinationAckerNode{ Name: "printer-acker", - Destination: node2.Destination, + Destination: node3.Destination, } - node2.AckerNode = node3 + node3.AckerNode = node4 stream.SetLogger(node1, logger) 
stream.SetLogger(node2, logger) stream.SetLogger(node3, logger) + stream.SetLogger(node4, logger) // put everything together - out := node1.Pub() - node2.Sub(out) + node2.Sub(node1.Pub()) + node3.Sub(node2.Pub()) var wg sync.WaitGroup - wg.Add(3) + wg.Add(4) + go runNode(ctx, &wg, node4) go runNode(ctx, &wg, node3) go runNode(ctx, &wg, node2) go runNode(ctx, &wg, node1) @@ -104,6 +110,7 @@ func Example_simpleStream() { // DBG received ack message_id=p/generator-10 node_id=generator // INF stopping source connector component=SourceNode node_id=generator // DBG received error on error channel error="error reading from source: stream not open" component=SourceNode node_id=generator + // DBG incoming messages channel closed component=SourceAckerNode node_id=generator-acker // DBG incoming messages channel closed component=DestinationNode node_id=printer // INF finished successfully } @@ -122,57 +129,62 @@ func Example_complexStream() { Source: generatorSource(ctrl, logger, "generator1", 10, time.Millisecond*10), PipelineTimer: noop.Timer{}, } - node2 := &stream.SourceNode{ + node2 := &stream.SourceAckerNode{ + Name: "generator1-acker", + Source: node1.Source, + } + node3 := &stream.SourceNode{ Name: "generator2", Source: generatorSource(ctrl, logger, "generator2", 10, time.Millisecond*10), PipelineTimer: noop.Timer{}, } - node3 := &stream.FaninNode{Name: "fanin"} - node4 := &stream.ProcessorNode{ + node4 := &stream.SourceAckerNode{ + Name: "generator2-acker", + Source: node3.Source, + } + node5 := &stream.FaninNode{Name: "fanin"} + node6 := &stream.ProcessorNode{ Name: "counter", Processor: counterProcessor(ctrl, &count), ProcessorTimer: noop.Timer{}, } - node5 := &stream.FanoutNode{Name: "fanout"} - node6 := &stream.DestinationNode{ + node7 := &stream.FanoutNode{Name: "fanout"} + node8 := &stream.DestinationNode{ Name: "printer1", Destination: printerDestination(ctrl, logger, "printer1"), ConnectorTimer: noop.Timer{}, } - node7 := &stream.DestinationNode{ + node9 := &stream.DestinationNode{ Name: "printer2", Destination: printerDestination(ctrl, logger, "printer2"), ConnectorTimer: noop.Timer{}, } - node8 := &stream.DestinationAckerNode{ + node10 := &stream.DestinationAckerNode{ Name: "printer1-acker", - Destination: node6.Destination, + Destination: node8.Destination, } - node6.AckerNode = node8 - node9 := &stream.DestinationAckerNode{ + node8.AckerNode = node10 + node11 := &stream.DestinationAckerNode{ Name: "printer2-acker", - Destination: node7.Destination, + Destination: node9.Destination, } - node7.AckerNode = node9 + node9.AckerNode = node11 // put everything together - out := node1.Pub() - node3.Sub(out) - out = node2.Pub() - node3.Sub(out) + node2.Sub(node1.Pub()) + node4.Sub(node3.Pub()) + + node5.Sub(node2.Pub()) + node5.Sub(node4.Pub()) - out = node3.Pub() - node4.Sub(out) - out = node4.Pub() - node5.Sub(out) + node6.Sub(node5.Pub()) + node7.Sub(node6.Pub()) - out = node5.Pub() - node6.Sub(out) - out = node5.Pub() - node7.Sub(out) + node8.Sub(node7.Pub()) + node9.Sub(node7.Pub()) // run nodes - nodes := []stream.Node{node1, node2, node3, node4, node5, node6, node7, node8, node9} + nodes := []stream.Node{node1, node2, node3, node4, node5, node6, node7, node8, node9, node10, node11} var wg sync.WaitGroup wg.Add(len(nodes)) @@ -186,7 +198,7 @@ func Example_complexStream() { 250*time.Millisecond, func() { node1.Stop(nil) - node2.Stop(nil) + node3.Stop(nil) }, ) // give the nodes some time to process the records, plus a bit of time to stop @@ -260,6 +272,8 @@ func 
Example_complexStream() { // DBG received ack message_id=p/generator1-10 node_id=generator1 // INF stopping source connector component=SourceNode node_id=generator1 // INF stopping source connector component=SourceNode node_id=generator2 + // DBG incoming messages channel closed component=SourceAckerNode node_id=generator1-acker + // DBG incoming messages channel closed component=SourceAckerNode node_id=generator2-acker // DBG received error on error channel error="error reading from source: stream not open" component=SourceNode node_id=generator1 // DBG received error on error channel error="error reading from source: stream not open" component=SourceNode node_id=generator2 // DBG incoming messages channel closed component=ProcessorNode node_id=counter From 6bdc89377dcd29f6dd48da78ef8f2d33ba9774f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 23 Jun 2022 16:44:06 +0200 Subject: [PATCH 13/46] add todo note about possible deadlock --- pkg/pipeline/stream/destination_acker.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index f2fa04861..2e437d89c 100644 --- a/pkg/pipeline/stream/destination_acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -113,6 +113,11 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { continue } + // TODO make sure acks are called in the right order or this will block + // forever. Right now we rely on connectors sending acks back in the + // correct order and this should generally be true, but we can't be + // completely sure and a badly written connector shouldn't provoke a + // deadlock. err = n.handleAck(msg, err) if err != nil { return err From 8b6dc73443a07d8c47cc8ecbbeb64b641d2845ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 23 Jun 2022 17:25:49 +0200 Subject: [PATCH 14/46] source acker node test --- pkg/pipeline/stream/destination_acker_test.go | 6 - pkg/pipeline/stream/source_acker_test.go | 155 ++++++++++++++++++ 2 files changed, 155 insertions(+), 6 deletions(-) create mode 100644 pkg/pipeline/stream/source_acker_test.go diff --git a/pkg/pipeline/stream/destination_acker_test.go b/pkg/pipeline/stream/destination_acker_test.go index 0a2a447af..66dfe759e 100644 --- a/pkg/pipeline/stream/destination_acker_test.go +++ b/pkg/pipeline/stream/destination_acker_test.go @@ -44,9 +44,6 @@ func TestAckerNode_Run_StopAfterWait(t *testing.T) { is.NoErr(err) }() - // give Go a chance to run the node - time.Sleep(time.Millisecond) - // note that there should be no calls to the destination at all if we didn't // receive any ExpectedAck call @@ -81,9 +78,6 @@ func TestAckerNode_Run_StopAfterExpectAck(t *testing.T) { is.NoErr(err) }() - // give Go a chance to run the node - time.Sleep(time.Millisecond) - // up to this point there should have been no calls to the destination // only after the call to ExpectAck should the node try to fetch any acks msg := &Message{ diff --git a/pkg/pipeline/stream/source_acker_test.go b/pkg/pipeline/stream/source_acker_test.go new file mode 100644 index 000000000..746e991ca --- /dev/null +++ b/pkg/pipeline/stream/source_acker_test.go @@ -0,0 +1,155 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stream + +import ( + "context" + "math/rand" + "strconv" + "sync" + "testing" + "time" + + "github.com/conduitio/conduit/pkg/connector/mock" + "github.com/conduitio/conduit/pkg/record" + "github.com/golang/mock/gomock" + "github.com/matryer/is" +) + +func TestSourceAckerNode_ForwardAck(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + src := mock.NewSource(ctrl) + + node := &SourceAckerNode{ + Name: "acker-node", + Source: src, + } + in := make(chan *Message) + out := node.Pub() + node.Sub(in) + + go func() { + err := node.Run(ctx) + is.NoErr(err) + }() + + want := &Message{Ctx: ctx, Record: record.Record{Position: []byte("foo")}} + // expect to receive an ack in the source after the message is acked + src.EXPECT().Ack(want.Ctx, want.Record.Position).Return(nil) + + in <- want + got := <-out + is.Equal(got, want) + + // ack should be propagated to the source, the mock will do the assertion + err := got.Ack() + is.NoErr(err) + + // gracefully stop node and give the test 1 second to finish + close(in) + + waitCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + select { + case <-waitCtx.Done(): + is.Fail() // expected node to stop running + case <-out: + // all good + } +} + +func TestSourceAckerNode_AckOrder(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + src := mock.NewSource(ctrl) + + const count = 1000 + const maxSleep = 1 * time.Millisecond + + node := &SourceAckerNode{ + Name: "acker-node", + Source: src, + } + in := make(chan *Message) + out := node.Pub() + node.Sub(in) + + go func() { + err := node.Run(ctx) + is.NoErr(err) + }() + + // first send messages through the node in the correct order + messages := make([]*Message, count) + for i := 0; i < count; i++ { + m := &Message{ + Ctx: ctx, + Record: record.Record{ + Position: []byte(strconv.Itoa(i)), // position is monotonically increasing + }, + } + in <- m + <-out + messages[i] = m + } + + // expect to receive an acks in the same order as the order of the messages + expectedPosition := 0 + expectedCalls := make([]*gomock.Call, count) + for i := 0; i < count; i++ { + expectedCalls[i] = src.EXPECT(). + Ack(ctx, messages[i].Record.Position). + Do(func(context.Context, record.Position) { expectedPosition++ }). + Return(nil) + } + gomock.InOrder(expectedCalls...) 
// enforce order + + // ack messages concurrently in random order + var wg sync.WaitGroup + wg.Add(count) + for i := 0; i < count; i++ { + go func(msg *Message) { + defer wg.Done() + // sleep for a random amount of time and ack the message + //nolint:gosec // math/rand is good enough for a test + time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond) + err := msg.Ack() + is.NoErr(err) + }(messages[i]) + } + + // gracefully stop node and give the test 1 second to finish + close(in) + + wgDone := make(chan struct{}) + go func() { + defer close(wgDone) + wg.Wait() + }() + + waitCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + select { + case <-waitCtx.Done(): + is.Fail() // expected to receive all acks in time + case <-wgDone: + // all good + } +} From d6c9641cb7d58fe5a410f66f9956081f4e389399 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 23 Jun 2022 20:03:01 +0200 Subject: [PATCH 15/46] remove message status dropped --- pkg/pipeline/stream/destination.go | 3 +- pkg/pipeline/stream/destination_acker.go | 47 +++-- pkg/pipeline/stream/doc.go | 17 +- pkg/pipeline/stream/fanin.go | 3 +- pkg/pipeline/stream/fanout.go | 34 +--- pkg/pipeline/stream/message.go | 134 ++++---------- pkg/pipeline/stream/message_test.go | 183 ++------------------ pkg/pipeline/stream/messagestatus_string.go | 5 +- pkg/pipeline/stream/metrics.go | 3 +- pkg/pipeline/stream/processor.go | 19 +- pkg/pipeline/stream/source.go | 3 +- pkg/pipeline/stream/source_acker.go | 4 +- 12 files changed, 91 insertions(+), 364 deletions(-) diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go index 0c75f2179..845ad455b 100644 --- a/pkg/pipeline/stream/destination.go +++ b/pkg/pipeline/stream/destination.go @@ -94,7 +94,8 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { writeTime := time.Now() err = n.Destination.Write(msg.Ctx, msg.Record) if err != nil { - n.AckerNode.ForgetAndDrop(msg) + n.AckerNode.Forget(msg) + _ = msg.Nack(err) // TODO think this through if it makes sense to return the error return cerrors.Errorf("error writing to destination: %w", err) } n.ConnectorTimer.Update(time.Since(writeTime)) diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index 2e437d89c..a7e65001d 100644 --- a/pkg/pipeline/stream/destination_acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -23,6 +23,7 @@ import ( "github.com/conduitio/conduit/pkg/connector" "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/conduitio/conduit/pkg/foundation/log" + "github.com/conduitio/conduit/pkg/foundation/multierror" "github.com/conduitio/conduit/pkg/plugin" "github.com/conduitio/conduit/pkg/record" ) @@ -70,13 +71,13 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { n.init() defer func() { - dropAllErr := n.teardown() + teardownErr := n.teardown(err) if err != nil { // we are already returning an error, just log this one - n.logger.Err(ctx, dropAllErr).Msg("acker node stopped without processing all messages") + n.logger.Err(ctx, teardownErr).Msg("acker node stopped without processing all messages") } else { - // return dropAllErr instead - err = dropAllErr + // return teardownErr instead + err = teardownErr } }() @@ -125,38 +126,39 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { } } -// teardown will drop all messages still in the cache and return an error in +// teardown will nack all messages still in the cache and 
return an error in
// case there were still unprocessed messages in the cache.
-func (n *DestinationAckerNode) teardown() error {
- var dropped int
+func (n *DestinationAckerNode) teardown(reason error) error {
+ var nacked int
+ var err error
 n.cache.Range(func(pos record.Position, msg *Message) bool {
- msg.Drop()
- dropped++
+ err = multierror.Append(err, msg.Nack(reason))
+ nacked++
 return true
 })
- if dropped > 0 {
- return cerrors.Errorf("dropped %d messages when stopping acker node", dropped)
+ if err != nil {
+ return cerrors.Errorf("nacked %d messages when stopping destination acker node, some nacks failed: %w", nacked, err)
+ }
+ if nacked > 0 {
+ return cerrors.Errorf("nacked %d messages when stopping destination acker node", nacked)
 }
 return nil
 }
 
 // handleAck either acks or nacks the message, depending on the supplied error.
-// If the nacking or acking fails, the message is dropped and the error is
-// returned.
+// If the nacking or acking fails, the error is returned.
 func (n *DestinationAckerNode) handleAck(msg *Message, err error) error {
 switch {
 case err != nil:
 n.logger.Trace(msg.Ctx).Err(err).Msg("nacking message")
 err = msg.Nack(err)
 if err != nil {
- msg.Drop()
 return cerrors.Errorf("error while nacking message: %w", err)
 }
 default:
 n.logger.Trace(msg.Ctx).Msg("acking message")
 err = msg.Ack()
 if err != nil {
- msg.Drop()
 return cerrors.Errorf("error while acking message: %w", err)
 }
 }
@@ -187,17 +189,10 @@ func (n *DestinationAckerNode) ExpectAck(msg *Message) error {
 return nil
 }
 
-// ForgetAndDrop signals the handler that an ack for this message won't be
-// received, and it should remove it from its cache. In case an ack for this
-// message wasn't yet received it drops the message, otherwise it does nothing.
-func (n *DestinationAckerNode) ForgetAndDrop(msg *Message) {
- _, ok := n.cache.LoadAndDelete(msg.Record.Position)
- if !ok {
- // message wasn't found in the cache, looks like the message was already
- // acked / nacked
- return
- }
- msg.Drop()
+// Forget signals the handler that an ack for this message won't be received,
+// and it should remove it from its cache.
+func (n *DestinationAckerNode) Forget(msg *Message) {
+ n.cache.LoadAndDelete(msg.Record.Position)
 }
 
 // Wait can be used to wait for the count of outstanding acks to drop to 0 or
diff --git a/pkg/pipeline/stream/doc.go b/pkg/pipeline/stream/doc.go
index eba22e8a9..39d9db9f7 100644
--- a/pkg/pipeline/stream/doc.go
+++ b/pkg/pipeline/stream/doc.go
@@ -28,18 +28,15 @@ A message can have one of these statuses:
 it's passed around between the nodes.
 Acked Once a node successfully processes the message (e.g. it is sent to
 the destination or is filtered out by a processor) it is acked.
- Nacked If some node fails to process the message it can nack the message
- and once it's successfully nacked (e.g. sent to a dead letter queue)
- it becomes nacked.
- Dropped If a node experiences a non-recoverable error or has to stop running
- without sending the message to the next node (e.g. force stop) it
- can drop the message, then the message status changes to dropped.
+ Nacked If some node fails to process the message it nacks the message. In
+ that case a handler can pick it up to send it to a dead letter
+ queue.
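+
+ For illustration, a nack handler that hands the failed record to a dead
+ letter queue could look roughly like this (a sketch only; writeToDLQ is a
+ hypothetical helper, not an API of this package):
+
+ msg.RegisterNackHandler(func(m *Message, reason error) error {
+ // persist the failed record together with the failure reason;
+ // returning nil marks the nack as successfully handled
+ return writeToDLQ(m.Ctx, m.Record, reason)
+ })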
-In other words, once a node receives a message it has 4 options for how to +In other words, once a node receives a message it has 3 options for how to handle it: it can either pass it to the next node (message stays open), ack the -message and keep running, nack the message and keep running or drop the message -and stop running. This means that no message will be left in an open status when -the pipeline stops. +message and keep running if ack is successful, nack the message and keep running +if nack is successful. This means that no message will be left in an open status +when the pipeline stops. Nodes can register functions on the message which will be called when the status of a message changes. For more information see StatusChangeHandler. diff --git a/pkg/pipeline/stream/fanin.go b/pkg/pipeline/stream/fanin.go index 1b6b69144..230be0f29 100644 --- a/pkg/pipeline/stream/fanin.go +++ b/pkg/pipeline/stream/fanin.go @@ -76,8 +76,7 @@ func (n *FaninNode) Run(ctx context.Context) error { select { case <-ctx.Done(): - msg.Drop() - return ctx.Err() + return msg.Nack(ctx.Err()) case n.out <- msg: } } diff --git a/pkg/pipeline/stream/fanout.go b/pkg/pipeline/stream/fanout.go index c6c2aa9ec..147bb4c2c 100644 --- a/pkg/pipeline/stream/fanout.go +++ b/pkg/pipeline/stream/fanout.go @@ -92,8 +92,6 @@ func (n *FanoutNode) Run(ctx context.Context) error { return msg.Ack() case <-msg.Nacked(): return cerrors.New("message was nacked by another node") - case <-msg.Dropped(): - return ErrMessageDropped } }), ) @@ -104,31 +102,10 @@ func (n *FanoutNode) Run(ctx context.Context) error { return msg.Nack(reason) }), ) - newMsg.RegisterDropHandler( - // wrap drop handler to make sure msg is not overwritten - // by the time drop handler is called - n.wrapDropHandler(msg, func(msg *Message, reason error) { - defer func() { - if err := recover(); err != nil { - if cerrors.Is(err.(error), ErrUnexpectedMessageStatus) { - // the unexpected message status is expected (I know, right?) - // this rare case might happen if one downstream node first - // nacks the message and afterwards another node tries to drop - // the message - // this is a valid use case, the panic is trying to make us - // notice all other invalid use cases - return - } - panic(err) // re-panic - } - }() - msg.Drop() - }), - ) select { case <-ctx.Done(): - msg.Drop() + _ = msg.Nack(ctx.Err()) // TODO handle this, don't approve PR unless this is handled return case n.out[i] <- newMsg: } @@ -161,15 +138,6 @@ func (n *FanoutNode) wrapNackHandler(origMsg *Message, f NackHandler) NackHandle } } -// wrapDropHandler modifies the drop handler, so it's called with the original -// message received by FanoutNode instead of the new message created by -// FanoutNode. -func (n *FanoutNode) wrapDropHandler(origMsg *Message, f DropHandler) DropHandler { - return func(_ *Message, reason error) { - f(origMsg, reason) - } -} - func (n *FanoutNode) Sub(in <-chan *Message) { if n.in != nil { panic("can't connect FanoutNode to more than one in") diff --git a/pkg/pipeline/stream/message.go b/pkg/pipeline/stream/message.go index 5a2c761e3..f8f917e5f 100644 --- a/pkg/pipeline/stream/message.go +++ b/pkg/pipeline/stream/message.go @@ -26,18 +26,16 @@ import ( "github.com/conduitio/conduit/pkg/record" ) -// MessageStatus represents the state of the message (acked, nacked, dropped or open). +// MessageStatus represents the state of the message (acked, nacked or open). 
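+// A message starts out open and transitions at most once, to either acked or
+// nacked; a conflicting second status change panics, while repeating the same
+// one is a no-op.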
type MessageStatus int
 
 const (
 MessageStatusAcked MessageStatus = iota
 MessageStatusNacked
 MessageStatusOpen
- MessageStatusDropped
 )
 
 var (
- ErrMessageDropped = cerrors.New("message is dropped")
 ErrUnexpectedMessageStatus = cerrors.New("unexpected message status")
 )
 
@@ -45,44 +43,40 @@ var (
 type Message struct {
 // Ctx is the context in which the record was fetched. It should be used for
 // any function calls when processing the message. If the context is done
- // the message should be dropped as soon as possible and not processed
+ // the message should be nacked as soon as possible and not processed
 // further.
 Ctx context.Context
 // Record represents a single record attached to the message.
 Record record.Record
 
- // acked, nacked and dropped are channels used to capture acks, nacks and
- // drops. When a message is acked, nacked or dropped the corresponding
- // channel is closed.
- acked chan struct{}
- nacked chan struct{}
- dropped chan struct{}
+ // acked and nacked are channels used to capture acks and nacks. When a
+ // message is acked or nacked the corresponding channel is closed.
+ acked chan struct{}
+ nacked chan struct{}
 
- // handler is executed when Ack, Nack or Drop is called.
+ // handler is executed when Ack or Nack is called.
 handler StatusChangeHandler
 // hasNackHandler is true if at least one nack handler was registered.
 hasNackHandler bool
 
- // ackNackReturnValue is cached the first time Ack, Nack or Drop is executed.
+ // ackNackReturnValue is cached the first time Ack or Nack is executed.
 ackNackReturnValue error
 
 // initOnce is guarding the initialization logic of a message.
 initOnce sync.Once
- // ackNackDropOnce is guarding the acking/nacking/dropping logic of a message.
- ackNackDropOnce sync.Once
- // handlerGuard guards fields ackHandlers and nackHandlers.
- handlerGuard sync.Mutex
+ // ackNackOnce is guarding the acking/nacking logic of a message.
+ ackNackOnce sync.Once
 }
 
 type (
- // StatusChangeHandler is executed when a message status changes. The handlers
- // are triggered by a call to either of these functions: Message.Nack,
- // Message.Ack, Message.Drop. These functions will block until the handlers
+ // StatusChangeHandler is executed when a message status changes. The
+ // handlers are triggered by a call to either of these functions:
+ // Message.Nack, Message.Ack. These functions will block until the handlers
 // finish handling the message and will return the error returned by the
 // handlers.
- // The function receives the message and the status change describing the old
- // and new message status as well as the reason for the status change in case of
- // a nack or drop.
+ // The function receives the message and the status change describing the
+ // old and new message status as well as the reason for the status change in
+ // case of a nack.
 StatusChangeHandler func(*Message, StatusChange) error
 
 // AckHandler is a variation of the StatusChangeHandler that is only called
 // when a message is acked. For more info see StatusChangeHandler.
 AckHandler func(*Message) error
 
 // NackHandler is a variation of the StatusChangeHandler that is only called
 // when a message is nacked. For more info see StatusChangeHandler.
 NackHandler func(*Message, error) error
-
- // DropHandler is a variation of the StatusChangeHandler that is only called
- // when a message is dropped. For more info see StatusChangeHandler.
- DropHandler func(*Message, error) ) // StatusChange is passed to StatusChangeHandler when the status of a message @@ -104,7 +94,7 @@ type StatusChange struct { Old MessageStatus New MessageStatus // Reason contains the error that triggered the status change in case of a - // nack or drop. + // nack. Reason error } @@ -113,7 +103,6 @@ func (m *Message) init() { m.initOnce.Do(func() { m.acked = make(chan struct{}) m.nacked = make(chan struct{}) - m.dropped = make(chan struct{}) // initialize empty status handler m.handler = func(msg *Message, change StatusChange) error { return nil } }) @@ -131,8 +120,6 @@ func (m *Message) ID() string { // reverse order of how they were registered. func (m *Message) RegisterStatusHandler(mw StatusChangeHandler) { m.init() - m.handlerGuard.Lock() - defer m.handlerGuard.Unlock() if m.Status() != MessageStatusOpen { panic(cerrors.Errorf("BUG: tried to register handler on message %s, it has already been handled", m.ID())) @@ -172,23 +159,7 @@ func (m *Message) RegisterNackHandler(mw NackHandler) { m.hasNackHandler = true } -// RegisterDropHandler is used to register a function that will be called when -// the message is dropped. This function can only be called if the message -// status is open, otherwise it panics. -func (m *Message) RegisterDropHandler(mw DropHandler) { - m.RegisterStatusHandler(func(msg *Message, change StatusChange) error { - if change.New != MessageStatusDropped { - return nil - } - mw(msg, change.Reason) - return nil - }) -} - func (m *Message) notifyStatusHandlers(status MessageStatus, reason error) error { - m.handlerGuard.Lock() - defer m.handlerGuard.Unlock() - return m.handler(m, StatusChange{ Old: m.Status(), New: status, @@ -197,26 +168,19 @@ func (m *Message) notifyStatusHandlers(status MessageStatus, reason error) error } // Ack marks the message as acked, calls the corresponding status change -// handlers and closes the channel returned by Acked. If an ack handler returns -// an error, the message is dropped instead, which means that registered status -// change handlers are again notified about the drop and the channel returned by -// Dropped is closed instead. +// handlers and closes the channel returned by Acked. Errors from ack handlers +// get collected and returned as a single error. If Ack returns an error, the +// caller node should stop processing new messages and return the error. // Calling Ack after the message has already been nacked will panic, while -// subsequent calls to Ack on an acked or dropped message are a noop and return -// the same value. +// subsequent calls to Ack on an acked message are a noop and return the same +// value. func (m *Message) Ack() error { m.init() - m.ackNackDropOnce.Do(func() { + m.ackNackOnce.Do(func() { m.ackNackReturnValue = m.notifyStatusHandlers(MessageStatusAcked, nil) - if m.ackNackReturnValue != nil { - // unsuccessful ack, message is dropped - _ = m.notifyStatusHandlers(MessageStatusDropped, m.ackNackReturnValue) - close(m.dropped) - return - } close(m.acked) }) - if s := m.Status(); s != MessageStatusAcked && s != MessageStatusDropped { + if s := m.Status(); s != MessageStatusAcked { panic(cerrors.Errorf("BUG: message %s ack failed, status is %s: %w", m.ID(), s, ErrUnexpectedMessageStatus)) } return m.ackNackReturnValue @@ -224,55 +188,29 @@ func (m *Message) Ack() error { // Nack marks the message as nacked, calls the registered status change handlers // and closes the channel returned by Nacked. 
If no nack handlers were -// registered or a nack handler returns an error, the message is dropped -// instead, which means that registered status change handlers are again -// notified about the drop and the channel returned by Dropped is closed -// instead. +// registered Nack will return an error. Errors from nack handlers get collected +// and returned as a single error. If Nack returns an error, the caller node +// should stop processing new messages and return the error. // Calling Nack after the message has already been acked will panic, while -// subsequent calls to Nack on a nacked or dropped message are a noop and return -// the same value. +// subsequent calls to Nack on a nacked message are a noop and return the same +// value. func (m *Message) Nack(reason error) error { m.init() - m.ackNackDropOnce.Do(func() { + m.ackNackOnce.Do(func() { if !m.hasNackHandler { // we enforce at least one nack handler, otherwise nacks will go unnoticed m.ackNackReturnValue = cerrors.Errorf("no nack handler on message %s: %w", m.ID(), reason) } else { m.ackNackReturnValue = m.notifyStatusHandlers(MessageStatusNacked, reason) } - if m.ackNackReturnValue != nil { - // unsuccessful nack, message is dropped - _ = m.notifyStatusHandlers(MessageStatusDropped, m.ackNackReturnValue) - close(m.dropped) - return - } close(m.nacked) }) - if s := m.Status(); s != MessageStatusNacked && s != MessageStatusDropped { + if s := m.Status(); s != MessageStatusNacked { panic(cerrors.Errorf("BUG: message %s nack failed, status is %s: %w", m.ID(), s, ErrUnexpectedMessageStatus)) } return m.ackNackReturnValue } -// Drop marks the message as dropped, calls the registered status change -// handlers and closes the channel returned by Dropped. -// Calling Drop after the message has already been acked or nacked will panic, -// while subsequent calls to Drop on a dropped message are a noop. -func (m *Message) Drop() { - m.init() - m.ackNackDropOnce.Do(func() { - m.ackNackReturnValue = ErrMessageDropped - err := m.notifyStatusHandlers(MessageStatusDropped, m.ackNackReturnValue) - if err != nil { - panic(cerrors.Errorf("BUG: drop handlers should never return an error (message %s): %w", m.ID(), err)) - } - close(m.dropped) - }) - if s := m.Status(); s != MessageStatusDropped { - panic(cerrors.Errorf("BUG: message %s drop failed, status is %s: %w", m.ID(), s, ErrUnexpectedMessageStatus)) - } -} - // Acked returns a channel that's closed when the message has been acked. // Successive calls to Acked return the same value. This function can be used to // wait for a message to be acked without notifying the acker. @@ -289,14 +227,6 @@ func (m *Message) Nacked() <-chan struct{} { return m.nacked } -// Dropped returns a channel that's closed when the message has been dropped. -// Successive calls to Dropped return the same value. This function can be used -// to wait for a message to be dropped without notifying the dropper. -func (m *Message) Dropped() <-chan struct{} { - m.init() - return m.dropped -} - // Clone returns a cloned message with the same content but separate ack and // nack handling. 
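+// The clone starts out open and with its own handlers, so acking or nacking
+// the clone does not ack or nack the original message.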
func (m *Message) Clone() *Message { @@ -313,8 +243,6 @@ func (m *Message) Status() MessageStatus { return MessageStatusAcked case <-m.nacked: return MessageStatusNacked - case <-m.dropped: - return MessageStatusDropped default: return MessageStatusOpen } diff --git a/pkg/pipeline/stream/message_test.go b/pkg/pipeline/stream/message_test.go index 12c35b2f5..4f8445614 100644 --- a/pkg/pipeline/stream/message_test.go +++ b/pkg/pipeline/stream/message_test.go @@ -84,9 +84,8 @@ func TestMessage_Ack_WithFailingHandler(t *testing.T) { msg Message wantErr = cerrors.New("oops") - ackedMessageHandlerCallCount int - droppedMessageHandlerCallCount int - statusMessageHandlerCallCount int + ackedMessageHandlerCallCount int + statusMessageHandlerCallCount int ) { @@ -104,18 +103,11 @@ func TestMessage_Ack_WithFailingHandler(t *testing.T) { ackedMessageHandlerCallCount++ return nil }) - // fourth handler should be called twice, once for ack, once for drop + // fourth handler should be called once msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { statusMessageHandlerCallCount++ return nil }) - // drop handler should be called after the ack fails - msg.RegisterDropHandler(func(msg *Message, reason error) { - if ackedMessageHandlerCallCount != 2 { - t.Fatal("expected acked message handlers to already be called") - } - droppedMessageHandlerCallCount++ - }) // nack handler should not be called msg.RegisterNackHandler(func(*Message, error) error { t.Fatalf("did not expect nack handler to be called") @@ -129,26 +121,14 @@ func TestMessage_Ack_WithFailingHandler(t *testing.T) { if err != wantErr { t.Fatalf("ack expected error %v, got: %v", wantErr, err) } - assertMessageIsDropped(t, &msg) + assertMessageIsAcked(t, &msg) if ackedMessageHandlerCallCount != 2 { t.Fatalf("expected acked message handler to be called twice, got %d calls", ackedMessageHandlerCallCount) } - if droppedMessageHandlerCallCount != 1 { - t.Fatalf("expected dropped message handler to be called once, got %d calls", droppedMessageHandlerCallCount) - } - if statusMessageHandlerCallCount != 2 { - t.Fatalf("expected status message handler to be called twice, got %d calls", statusMessageHandlerCallCount) + if statusMessageHandlerCallCount != 1 { + t.Fatalf("expected status message handler to be called once, got %d calls", statusMessageHandlerCallCount) } } - - // nacking the message should return the same error - err := msg.Nack(cerrors.New("reason")) - if err != wantErr { - t.Fatalf("nack expected error %v, got %v", wantErr, err) - } - - // dropping the message shouldn't do anything - msg.Drop() } func TestMessage_Nack_WithoutHandler(t *testing.T) { @@ -161,20 +141,14 @@ func TestMessage_Nack_WithoutHandler(t *testing.T) { if err1 == nil { t.Fatal("nack expected error, got nil") } - assertMessageIsDropped(t, &msg) + assertMessageIsNacked(t, &msg) // nacking again should return the same error err2 := msg.Nack(cerrors.New("reason")) if err1 != err2 { t.Fatalf("nack expected error %v, got %v", err1, err2) } - assertMessageIsDropped(t, &msg) - - // acking the message should return the same error - err3 := msg.Ack() - if err1 != err3 { - t.Fatalf("ack expected error %v, got %v", err1, err3) - } + assertMessageIsNacked(t, &msg) } func TestMessage_Nack_WithHandler(t *testing.T) { @@ -218,9 +192,8 @@ func TestMessage_Nack_WithFailingHandler(t *testing.T) { msg Message wantErr = cerrors.New("oops") - nackedMessageHandlerCallCount int - droppedMessageHandlerCallCount int - statusMessageHandlerCallCount int + 
nackedMessageHandlerCallCount int + statusMessageHandlerCallCount int ) { @@ -238,18 +211,11 @@ func TestMessage_Nack_WithFailingHandler(t *testing.T) { nackedMessageHandlerCallCount++ return nil }) - // fourth handler should be called twice, once for ack, once for drop + // fourth handler should be called once msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { statusMessageHandlerCallCount++ return nil }) - // drop handler should be called after the nack fails - msg.RegisterDropHandler(func(msg *Message, reason error) { - if nackedMessageHandlerCallCount != 2 { - t.Fatal("expected nacked message handlers to already be called") - } - droppedMessageHandlerCallCount++ - }) // ack handler should not be called msg.RegisterAckHandler(func(*Message) error { t.Fatalf("did not expect ack handler to be called") @@ -263,90 +229,16 @@ func TestMessage_Nack_WithFailingHandler(t *testing.T) { if err != wantErr { t.Fatalf("nack expected error %v, got: %v", wantErr, err) } - assertMessageIsDropped(t, &msg) + assertMessageIsNacked(t, &msg) if nackedMessageHandlerCallCount != 2 { t.Fatalf("expected nacked message handler to be called twice, got %d calls", nackedMessageHandlerCallCount) } - if droppedMessageHandlerCallCount != 1 { - t.Fatalf("expected dropped message handler to be called once, got %d calls", droppedMessageHandlerCallCount) - } - if statusMessageHandlerCallCount != 2 { - t.Fatalf("expected status message handler to be called twice, got %d calls", statusMessageHandlerCallCount) - } - } - - // acking the message should return the same error - err := msg.Ack() - if err != wantErr { - t.Fatalf("ack expected error %v, got %v", wantErr, err) - } - - // dropping the message shouldn't do anything - msg.Drop() -} - -func TestMessage_Drop_WithoutHandler(t *testing.T) { - var msg Message - - assertMessageIsOpen(t, &msg) - - msg.Drop() - assertMessageIsDropped(t, &msg) - - // doing the same thing again shouldn't do anything - msg.Drop() - assertMessageIsDropped(t, &msg) -} - -func TestMessage_Drop_WithHandler(t *testing.T) { - var ( - msg Message - - droppedMessageHandlerCallCount int - statusMessageHandlerCallCount int - ) - - { - msg.RegisterDropHandler(func(msg *Message, reason error) { - droppedMessageHandlerCallCount++ - }) - // second handler should be called once for drop - msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { - statusMessageHandlerCallCount++ - return nil - }) - } - - // doing the same thing twice should have the same result - for i := 0; i < 2; i++ { - msg.Drop() - assertMessageIsDropped(t, &msg) - if droppedMessageHandlerCallCount != 1 { - t.Fatalf("expected dropped message handler to be called once, got %d calls", droppedMessageHandlerCallCount) - } if statusMessageHandlerCallCount != 1 { t.Fatalf("expected status message handler to be called once, got %d calls", statusMessageHandlerCallCount) } } } -func TestMessage_Drop_WithFailingHandler(t *testing.T) { - var msg Message - - // handler return error for drop - msg.RegisterStatusHandler(func(msg *Message, change StatusChange) error { - return cerrors.New("oops") - }) - - defer func() { - if recover() == nil { - t.Fatalf("expected msg.Drop to panic") - } - }() - - msg.Drop() -} - func TestMessage_StatusChangeTwice(t *testing.T) { assertAckPanics := func(msg *Message) { defer func() { @@ -364,16 +256,8 @@ func TestMessage_StatusChangeTwice(t *testing.T) { }() _ = msg.Nack(nil) } - assertDropPanics := func(msg *Message) { - defer func() { - if recover() == nil { - t.Fatalf("expected 
msg.Drop to panic") - } - }() - msg.Drop() - } - // nack or drop after the message is acked should panic + // nack after the message is acked should panic t.Run("acked message", func(t *testing.T) { var msg Message err := msg.Ack() @@ -381,7 +265,6 @@ func TestMessage_StatusChangeTwice(t *testing.T) { t.Fatalf("ack did not expect error, got %v", err) } assertNackPanics(&msg) - assertDropPanics(&msg) }) // registering a handler after the message is nacked should panic @@ -394,23 +277,6 @@ func TestMessage_StatusChangeTwice(t *testing.T) { t.Fatalf("ack did not expect error, got %v", err) } assertAckPanics(&msg) - assertDropPanics(&msg) - }) - - // registering a handler after the message is dropped should panic - t.Run("dropped message", func(t *testing.T) { - var msg Message - msg.Drop() - - err := msg.Ack() - if err != ErrMessageDropped { - t.Fatalf("expected %v, got %v", ErrMessageDropped, err) - } - - err = msg.Nack(nil) - if err != ErrMessageDropped { - t.Fatalf("expected %v, got %v", ErrMessageDropped, err) - } }) } @@ -431,14 +297,6 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { }() msg.RegisterNackHandler(func(*Message, error) error { return nil }) } - assertRegisterDropHandlerPanics := func(msg *Message) { - defer func() { - if recover() == nil { - t.Fatalf("expected msg.RegisterDropHandler to panic") - } - }() - msg.RegisterDropHandler(func(*Message, error) {}) - } // registering a handler after the message is acked should panic t.Run("acked message", func(t *testing.T) { @@ -449,7 +307,6 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { } assertRegisterAckHandlerPanics(&msg) assertRegisterNackHandlerPanics(&msg) - assertRegisterDropHandlerPanics(&msg) }) // registering a handler after the message is nacked should panic @@ -463,16 +320,6 @@ func TestMessage_RegisterHandlerFail(t *testing.T) { } assertRegisterAckHandlerPanics(&msg) assertRegisterNackHandlerPanics(&msg) - assertRegisterDropHandlerPanics(&msg) - }) - - // registering a handler after the message is dropped should panic - t.Run("dropped message", func(t *testing.T) { - var msg Message - msg.Drop() - assertRegisterAckHandlerPanics(&msg) - assertRegisterNackHandlerPanics(&msg) - assertRegisterDropHandlerPanics(&msg) }) } @@ -487,7 +334,3 @@ func assertMessageIsNacked(t *testing.T, msg *Message) { func assertMessageIsOpen(t *testing.T, msg *Message) { assert.Equal(t, MessageStatusOpen, msg.Status()) } - -func assertMessageIsDropped(t *testing.T, msg *Message) { - assert.Equal(t, MessageStatusDropped, msg.Status()) -} diff --git a/pkg/pipeline/stream/messagestatus_string.go b/pkg/pipeline/stream/messagestatus_string.go index 8b339bc5e..d748cd0ba 100644 --- a/pkg/pipeline/stream/messagestatus_string.go +++ b/pkg/pipeline/stream/messagestatus_string.go @@ -11,12 +11,11 @@ func _() { _ = x[MessageStatusAcked-0] _ = x[MessageStatusNacked-1] _ = x[MessageStatusOpen-2] - _ = x[MessageStatusDropped-3] } -const _MessageStatus_name = "AckedNackedOpenDropped" +const _MessageStatus_name = "AckedNackedOpen" -var _MessageStatus_index = [...]uint8{0, 5, 11, 15, 22} +var _MessageStatus_index = [...]uint8{0, 5, 11, 15} func (i MessageStatus) String() string { if i < 0 || i >= MessageStatus(len(_MessageStatus_index)-1) { diff --git a/pkg/pipeline/stream/metrics.go b/pkg/pipeline/stream/metrics.go index b6dd3bb13..0fce43dd7 100644 --- a/pkg/pipeline/stream/metrics.go +++ b/pkg/pipeline/stream/metrics.go @@ -65,8 +65,7 @@ func (n *MetricsNode) Run(ctx context.Context) error { err = n.base.Send(ctx, n.logger, msg) if err != nil 
{
- msg.Drop()
- return err
+ return msg.Nack(err)
 }
 }
 }
diff --git a/pkg/pipeline/stream/processor.go b/pkg/pipeline/stream/processor.go
index ddbfddb5b..acbd5416c 100644
--- a/pkg/pipeline/stream/processor.go
+++ b/pkg/pipeline/stream/processor.go
@@ -55,28 +55,27 @@ func (n *ProcessorNode) Run(ctx context.Context) error {
 n.ProcessorTimer.Update(time.Since(executeTime))
 if err != nil {
 // Check for Skipped records
- if err == processor.ErrSkipRecord {
+ switch err {
+ case processor.ErrSkipRecord:
 // NB: Ack skipped messages since they've been correctly handled
 err := msg.Ack()
 if err != nil {
 return cerrors.Errorf("failed to ack skipped message: %w", err)
 }
- continue
- }
- err = msg.Nack(err)
- if err != nil {
- msg.Drop()
- return cerrors.Errorf("error applying transform: %w", err)
+ default:
+ err = msg.Nack(err)
+ if err != nil {
+ return cerrors.Errorf("error executing processor: %w", err)
+ }
 }
- // nack was handled successfully, we recovered
+ // the error was handled successfully and we recovered
 continue
 }
 msg.Record = rec
 
 err = n.base.Send(ctx, n.logger, msg)
 if err != nil {
- msg.Drop()
- return err
+ return msg.Nack(err)
 }
 }
 }
diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go
index 9c827c5fc..cb5efbdb2 100644
--- a/pkg/pipeline/stream/source.go
+++ b/pkg/pipeline/stream/source.go
@@ -116,8 +116,7 @@ func (n *SourceNode) Run(ctx context.Context) (err error) {
 
 err = n.base.Send(ctx, n.logger, msg)
 if err != nil {
- msg.Drop()
- return err
+ return msg.Nack(err)
 }
 }
 }
diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go
index 9dc6c759f..485af8458 100644
--- a/pkg/pipeline/stream/source_acker.go
+++ b/pkg/pipeline/stream/source_acker.go
@@ -56,13 +56,13 @@ func (n *SourceAckerNode) Run(ctx context.Context) error {
 
 // enqueue message in semaphore
 ticket := n.sem.Enqueue()
+ // TODO make sure that if an ack/nack fails we stop forwarding acks
 n.registerAckHandler(msg, ticket)
 n.registerNackHandler(msg, ticket)
 
 err = n.base.Send(ctx, n.logger, msg)
 if err != nil {
- msg.Drop()
- return err
+ return msg.Nack(err)
 }
 }
 }
From f3273f3d29c67076dc60fcbbfaccad1148ba488f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Fri, 24 Jun 2022 15:25:52 +0200
Subject: [PATCH 16/46] document behavior, fanout node return nacked message error

---
 pkg/pipeline/stream/destination.go | 8 +++++++-
 pkg/pipeline/stream/destination_acker.go | 5 ++---
 pkg/pipeline/stream/fanout.go | 23 ++++++++++++++++++++++-
 pkg/pipeline/stream/source_acker.go | 1 -
 4 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go
index 845ad455b..23f775b3a 100644
--- a/pkg/pipeline/stream/destination.go
+++ b/pkg/pipeline/stream/destination.go
@@ -94,8 +94,14 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) {
 writeTime := time.Now()
 err = n.Destination.Write(msg.Ctx, msg.Record)
 if err != nil {
+ // An error in Write is a fatal error; we probably won't be able to
+ // process any further messages because there is a problem in the
+ // communication with the plugin. We need to let the acker node know
+ // that it shouldn't wait to receive an ack for the message, nack the
+ // message so it is not left open, and then return the error to stop
+ // the pipeline.
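+ // Removing the message from the acker node's cache first ensures the
+ // acker node is not left waiting for an ack that will never arrive.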
n.AckerNode.Forget(msg)
- _ = msg.Nack(err) // TODO think this through if it makes sense to return the error
+ _ = msg.Nack(err)
 return cerrors.Errorf("error writing to destination: %w", err)
 }
 n.ConnectorTimer.Update(time.Since(writeTime))
diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go
index a7e65001d..317707ecb 100644
--- a/pkg/pipeline/stream/destination_acker.go
+++ b/pkg/pipeline/stream/destination_acker.go
@@ -116,9 +116,8 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) {
 
 // TODO make sure acks are called in the right order or this will block
 // forever. Right now we rely on connectors sending acks back in the
- // correct order and this should generally be true, but we can't be
- // completely sure and a badly written connector shouldn't provoke a
- // deadlock.
+ // correct order and this should generally be true, but a badly written
+ // connector could still provoke a deadlock; we could prevent that.
 err = n.handleAck(msg, err)
 if err != nil {
 return err
diff --git a/pkg/pipeline/stream/fanout.go b/pkg/pipeline/stream/fanout.go
index 147bb4c2c..e67779e90 100644
--- a/pkg/pipeline/stream/fanout.go
+++ b/pkg/pipeline/stream/fanout.go
@@ -105,7 +105,9 @@ func (n *FanoutNode) Run(ctx context.Context) error {
 
 select {
 case <-ctx.Done():
- _ = msg.Nack(ctx.Err()) // TODO handle this, don't approve PR unless this is handled
+ // we can ignore the error; it will show up on the
+ // original msg
+ _ = newMsg.Nack(ctx.Err())
 return
 case n.out[i] <- newMsg:
 }
@@ -116,6 +118,25 @@ func (n *FanoutNode) Run(ctx context.Context) error {
 // also there is no need to listen to ctx.Done because that's what
 // the go routines are doing already
 wg.Wait()
+
+ // check if the context was cancelled
+ if ctx.Err() != nil {
+ // context was closed - if the message was nacked there's a high
+ // chance it was nacked in this node by one of the goroutines
+ if msg.Status() == MessageStatusNacked {
+ // check if the nack returned an error (Nack is idempotent
+ // and will return the same error as the first call) and
+ // return it if it did
+ if err := msg.Nack(nil); err != nil {
+ return err
+ }
+ }
+ // the message was not nacked, so it must have been sent to all
+ // downstream nodes just before the context got cancelled; we
+ // don't care about the message anymore and simply return the
+ // context error
+ return ctx.Err()
+ }
 }
 }
 }
diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go
index 485af8458..36005a13d 100644
--- a/pkg/pipeline/stream/source_acker.go
+++ b/pkg/pipeline/stream/source_acker.go
@@ -56,7 +56,6 @@ func (n *SourceAckerNode) Run(ctx context.Context) error {
 
 // enqueue message in semaphore
 ticket := n.sem.Enqueue()
- // TODO make sure that if an ack/nack fails we stop forwarding acks
 n.registerAckHandler(msg, ticket)
 n.registerNackHandler(msg, ticket)
 
From 94b4f437c20701932a4d6f89779326bbafaca418 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Tue, 28 Jun 2022 20:46:15 +0200
Subject: [PATCH 17/46] don't forward acks after a failed ack/nack

---
 pkg/pipeline/stream/source_acker.go | 49 ++++--
 pkg/pipeline/stream/source_acker_test.go | 186 +++++++++++++++++++----
 2 files changed, 196 insertions(+), 39 deletions(-)

diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go
index 9dc6c759f..c716f306e 100644
--- a/pkg/pipeline/stream/source_acker.go
+++ b/pkg/pipeline/stream/source_acker.go
@@ -35,6 +35,10 @@ type SourceAckerNode struct {
 // sem ensures acks are sent to the source in the correct order and only one
 // at a time
 sem semaphore.Simple
+ // fail is set to true once the first ack/nack fails. From that point on we
+ // can't guarantee that acks will be delivered to the source in the correct
+ // order anymore, so we completely stop processing acks/nacks.
+ fail bool
 }
 
 func (n *SourceAckerNode) ID() string {
@@ -71,19 +75,30 @@ func (n *SourceAckerNode) registerAckHandler(msg *Message, ticket semaphore.Tick
 msg.RegisterAckHandler(
 func(msg *Message) (err error) {
 defer func() {
- tmpErr := n.sem.Release(ticket)
 if err != nil {
- // we are already returning an error, log this one instead
- n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for ack")
- return
+ n.fail = true
+ }
+ tmpErr := n.sem.Release(ticket)
+ if tmpErr != nil {
+ if err != nil {
+ // we are already returning an error, log this one instead
+ n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for ack")
+ } else {
+ err = tmpErr
+ }
 }
- err = tmpErr
 }()
 n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for ack")
 err = n.sem.Acquire(ticket)
 if err != nil {
 return cerrors.Errorf("could not acquire semaphore for ack: %w", err)
 }
+
+ if n.fail {
+ n.logger.Trace(msg.Ctx).Msg("blocking forwarding of ack to source connector, because another message failed to be acked/nacked")
+ return cerrors.New("another message failed to be acked/nacked")
+ }
+
 n.logger.Trace(msg.Ctx).Msg("forwarding ack to source connector")
 return n.Source.Ack(msg.Ctx, msg.Record.Position)
 },
@@ -94,22 +109,34 @@ func (n *SourceAckerNode) registerNackHandler(msg *Message, ticket semaphore.Tic
 msg.RegisterNackHandler(
 func(msg *Message, reason error) (err error) {
 defer func() {
- tmpErr := n.sem.Release(ticket)
 if err != nil {
- // we are already returning an error, log this one instead
- n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for nack")
- return
+ n.fail = true
+ }
+ tmpErr := n.sem.Release(ticket)
+ if tmpErr != nil {
+ if err != nil {
+ // we are already returning an error, log this one instead
+ n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for nack")
+ } else {
+ err = tmpErr
+ }
 }
- err = tmpErr
 }()
 n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for nack")
 err = n.sem.Acquire(ticket)
 if err != nil {
 return cerrors.Errorf("could not acquire semaphore for nack: %w", err)
 }
+
+ if n.fail {
+ n.logger.Trace(msg.Ctx).Msg("blocking forwarding of nack to DLQ handler, because another message failed to be acked/nacked")
+ return cerrors.New("another message failed to be acked/nacked")
+ }
+
 n.logger.Trace(msg.Ctx).Msg("forwarding nack to DLQ handler")
 // TODO implement DLQ and call it here, right now any nacked message
- // will just stop the pipeline because we don't support DLQs
+ // will just stop the pipeline because we don't support DLQs;
+ // don't forget to forward the ack to the source if the DLQ call succeeds
 // https://github.com/ConduitIO/conduit/issues/306
 return cerrors.New("no DLQ handler configured")
 },
diff --git a/pkg/pipeline/stream/source_acker_test.go b/pkg/pipeline/stream/source_acker_test.go
index 746e991ca..12307cbd6 100644
--- a/pkg/pipeline/stream/source_acker_test.go
+++ b/pkg/pipeline/stream/source_acker_test.go
@@ -16,12 +16,14 @@ package stream
 
 import (
 "context"
+ "errors"
 "math/rand"
 "strconv"
 "sync"
 "testing"
 "time"
 
+ "github.com/conduitio/conduit/pkg/connector"
 "github.com/conduitio/conduit/pkg/connector/mock"
 "github.com/conduitio/conduit/pkg/record"
"github.com/golang/mock/gomock" @@ -33,19 +35,9 @@ func TestSourceAckerNode_ForwardAck(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) src := mock.NewSource(ctrl) + helper := sourceAckerNodeTestHelper{} - node := &SourceAckerNode{ - Name: "acker-node", - Source: src, - } - in := make(chan *Message) - out := node.Pub() - node.Sub(in) - - go func() { - err := node.Run(ctx) - is.NoErr(err) - }() + _, in, out := helper.newSourceAckerNode(ctx, is, src) want := &Message{Ctx: ctx, Record: record.Record{Position: []byte("foo")}} // expect to receive an ack in the source after the message is acked @@ -78,10 +70,126 @@ func TestSourceAckerNode_AckOrder(t *testing.T) { ctx := context.Background() ctrl := gomock.NewController(t) src := mock.NewSource(ctrl) + helper := sourceAckerNodeTestHelper{} + + _, in, out := helper.newSourceAckerNode(ctx, is, src) + // send 1000 messages through the node + messages := helper.sendMessages(ctx, 1000, in, out) + // expect all messages to be acked + expectedCalls := helper.expectAcks(ctx, messages, src) + gomock.InOrder(expectedCalls...) // enforce order of acks - const count = 1000 - const maxSleep = 1 * time.Millisecond + // ack messages concurrently in random order, expect no errors + var wg sync.WaitGroup + helper.ackMessagesConcurrently( + &wg, + messages, + func(msg *Message, err error) { + is.NoErr(err) + }, + ) + + // gracefully stop node and give the test 1 second to finish + close(in) + + err := helper.wait(ctx, &wg, time.Second) + is.NoErr(err) // expected to receive acks in time +} + +func TestSourceAckerNode_FailedAck(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + src := mock.NewSource(ctrl) + helper := sourceAckerNodeTestHelper{} + + _, in, out := helper.newSourceAckerNode(ctx, is, src) + // send 1000 messages through the node + messages := helper.sendMessages(ctx, 1000, in, out) + // expect first 500 to be acked successfully + expectedCalls := helper.expectAcks(ctx, messages[:500], src) + gomock.InOrder(expectedCalls...) // enforce order of acks + // the 500th message should be acked unsuccessfully + wantErr := errors.New("test error") + src.EXPECT(). + Ack(ctx, messages[500].Record.Position). + Return(wantErr). 
+ After(expectedCalls[len(expectedCalls)-1]) // should happen after last acked call + + // ack messages concurrently in random order, expect errors for second half + var wg sync.WaitGroup + helper.ackMessagesConcurrently(&wg, messages[:500], + func(msg *Message, err error) { + is.NoErr(err) // expected messages from the first half to be acked successfully + }, + ) + helper.ackMessagesConcurrently(&wg, messages[500:501], + func(msg *Message, err error) { + is.Equal(err, wantErr) // expected the middle message ack to fail with specific error + }, + ) + helper.ackMessagesConcurrently(&wg, messages[501:], + func(msg *Message, err error) { + is.True(err != nil) // expected messages from the second half to be acked unsuccessfully + is.True(err != wantErr) + }, + ) + + // gracefully stop node and give the test 1 second to finish + close(in) + + err := helper.wait(ctx, &wg, time.Second) + is.NoErr(err) // expected to receive acks in time +} +func TestSourceAckerNode_FailedNack(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + src := mock.NewSource(ctrl) + helper := sourceAckerNodeTestHelper{} + + _, in, out := helper.newSourceAckerNode(ctx, is, src) + // send 1000 messages through the node + messages := helper.sendMessages(ctx, 1000, in, out) + // expect first 500 to be acked successfully + expectedCalls := helper.expectAcks(ctx, messages[:500], src) + gomock.InOrder(expectedCalls...) // enforce order of acks + // the 500th message will be nacked unsuccessfully, no more acks should be received after that + + // ack messages concurrently in random order + var wg sync.WaitGroup + helper.ackMessagesConcurrently(&wg, messages[:500], + func(msg *Message, err error) { + is.NoErr(err) // expected messages from the first half to be acked successfully + }, + ) + helper.ackMessagesConcurrently(&wg, messages[501:], + func(msg *Message, err error) { + is.True(err != nil) // expected messages from the second half to be acked unsuccessfully + }, + ) + + wantErr := errors.New("test error") + err := messages[500].Nack(wantErr) + is.True(err != nil) // expected the 500th message nack to fail with specific error + + // gracefully stop node and give the test 1 second to finish + close(in) + + err = helper.wait(ctx, &wg, time.Second) + is.NoErr(err) // expected to receive acks in time +} + +// sourceAckerNodeTestHelper groups together helper functions for tests related +// to SourceAckerNode. 
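+// The general flow of these tests: build the node, push messages through it
+// in order, ack them concurrently in random order and let gomock.InOrder
+// assert that the acks reach the source in the original order.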
+type sourceAckerNodeTestHelper struct{} + +func (sourceAckerNodeTestHelper) newSourceAckerNode( + ctx context.Context, + is *is.I, + src connector.Source, +) (*SourceAckerNode, chan<- *Message, <-chan *Message) { node := &SourceAckerNode{ Name: "acker-node", Source: src, @@ -95,7 +203,15 @@ func TestSourceAckerNode_AckOrder(t *testing.T) { is.NoErr(err) }() - // first send messages through the node in the correct order + return node, in, out +} + +func (sourceAckerNodeTestHelper) sendMessages( + ctx context.Context, + count int, + in chan<- *Message, + out <-chan *Message, +) []*Message { messages := make([]*Message, count) for i := 0; i < count; i++ { m := &Message{ @@ -108,20 +224,35 @@ func TestSourceAckerNode_AckOrder(t *testing.T) { <-out messages[i] = m } + return messages +} + +func (sourceAckerNodeTestHelper) expectAcks( + ctx context.Context, + messages []*Message, + src *mock.Source, +) []*gomock.Call { + count := len(messages) - // expect to receive an acks in the same order as the order of the messages - expectedPosition := 0 + // expect to receive acks successfully expectedCalls := make([]*gomock.Call, count) for i := 0; i < count; i++ { expectedCalls[i] = src.EXPECT(). Ack(ctx, messages[i].Record.Position). - Do(func(context.Context, record.Position) { expectedPosition++ }). Return(nil) } - gomock.InOrder(expectedCalls...) // enforce order - // ack messages concurrently in random order - var wg sync.WaitGroup + return expectedCalls +} + +func (sourceAckerNodeTestHelper) ackMessagesConcurrently( + wg *sync.WaitGroup, + messages []*Message, + assertAckErr func(*Message, error), +) { + const maxSleep = time.Millisecond + count := len(messages) + wg.Add(count) for i := 0; i < count; i++ { go func(msg *Message) { @@ -130,26 +261,25 @@ func TestSourceAckerNode_AckOrder(t *testing.T) { //nolint:gosec // math/rand is good enough for a test time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond) err := msg.Ack() - is.NoErr(err) + assertAckErr(msg, err) }(messages[i]) } +} - // gracefully stop node and give the test 1 second to finish - close(in) - +func (sourceAckerNodeTestHelper) wait(ctx context.Context, wg *sync.WaitGroup, timeout time.Duration) error { wgDone := make(chan struct{}) go func() { defer close(wgDone) wg.Wait() }() - waitCtx, cancel := context.WithTimeout(ctx, time.Second) + waitCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() select { case <-waitCtx.Done(): - is.Fail() // expected to receive all acks in time + return waitCtx.Err() case <-wgDone: - // all good + return nil } } From 3e283dd58df698410107e220d09ed981100d168d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Tue, 28 Jun 2022 20:49:54 +0200 Subject: [PATCH 18/46] use cerrors --- pkg/pipeline/stream/source_acker_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/pipeline/stream/source_acker_test.go b/pkg/pipeline/stream/source_acker_test.go index 12307cbd6..ad3d53385 100644 --- a/pkg/pipeline/stream/source_acker_test.go +++ b/pkg/pipeline/stream/source_acker_test.go @@ -16,7 +16,6 @@ package stream import ( "context" - "errors" "math/rand" "strconv" "sync" @@ -25,6 +24,7 @@ import ( "github.com/conduitio/conduit/pkg/connector" "github.com/conduitio/conduit/pkg/connector/mock" + "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/conduitio/conduit/pkg/record" "github.com/golang/mock/gomock" "github.com/matryer/is" @@ -110,7 +110,7 @@ func TestSourceAckerNode_FailedAck(t *testing.T) { 
expectedCalls := helper.expectAcks(ctx, messages[:500], src) gomock.InOrder(expectedCalls...) // enforce order of acks // the 500th message should be acked unsuccessfully - wantErr := errors.New("test error") + wantErr := cerrors.New("test error") src.EXPECT(). Ack(ctx, messages[500].Record.Position). Return(wantErr). @@ -170,7 +170,7 @@ func TestSourceAckerNode_FailedNack(t *testing.T) { }, ) - wantErr := errors.New("test error") + wantErr := cerrors.New("test error") err := messages[500].Nack(wantErr) is.True(err != nil) // expected the 500th message nack to fail with specific error From ce14149064fbaf6391f2ccc283a7ff0cf44baa5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 13:22:38 +0200 Subject: [PATCH 19/46] update plugin interface --- go.mod | 2 +- go.sum | 2 + pkg/plugin/acceptance_testing.go | 338 ++++++++++--------------------- pkg/plugin/plugin.go | 28 ++- 4 files changed, 130 insertions(+), 240 deletions(-) diff --git a/go.mod b/go.mod index ca0c9d453..d31ba4401 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/conduitio/conduit-connector-generator v0.1.0 github.com/conduitio/conduit-connector-kafka v0.1.1 github.com/conduitio/conduit-connector-postgres v0.1.0 - github.com/conduitio/conduit-connector-protocol v0.2.0 + github.com/conduitio/conduit-connector-protocol v0.2.1-0.20220608133528-f466a956bd4d github.com/conduitio/conduit-connector-s3 v0.1.1 github.com/conduitio/conduit-connector-sdk v0.2.0 github.com/dgraph-io/badger/v3 v3.2103.2 diff --git a/go.sum b/go.sum index 2e59c153b..92d9402d0 100644 --- a/go.sum +++ b/go.sum @@ -157,6 +157,8 @@ github.com/conduitio/conduit-connector-postgres v0.1.0 h1:Dj2S1NrwnJaUOgQqb9MjGS github.com/conduitio/conduit-connector-postgres v0.1.0/go.mod h1:ug4N+2pGKDbG5UN++w7xRqb0A5Ua2J5Ld5wUzLbU1Q0= github.com/conduitio/conduit-connector-protocol v0.2.0 h1:gwYXVKEMgTtU67ephQ5WwTGIDbT/eTLA9Mdr9Bnbqxc= github.com/conduitio/conduit-connector-protocol v0.2.0/go.mod h1:udCU2AkLcYQoLjAO06tHVL2iFJPw+DamK+wllnj50hk= +github.com/conduitio/conduit-connector-protocol v0.2.1-0.20220608133528-f466a956bd4d h1:f3R0yPiH45hDZwNcYMSzKJP6LOGQPELCqW9OkZmd2lA= +github.com/conduitio/conduit-connector-protocol v0.2.1-0.20220608133528-f466a956bd4d/go.mod h1:1nmTaD+l3mvq3PnMmPPx8UxHPM53Xk8zGT3URu2Xx2M= github.com/conduitio/conduit-connector-s3 v0.1.1 h1:10uIakNmF65IN5TNJB1qPWC6vbdGgrHEMg8r+dxDrc8= github.com/conduitio/conduit-connector-s3 v0.1.1/go.mod h1:xpfBzOGjZkkglTmF1444qEjXuEx+do1PTYZNroPFcSE= github.com/conduitio/conduit-connector-sdk v0.2.0 h1:yReJT3SOAGqJIlk59WC5FPgpv0Gg+NG4NFj6FJ89XnM= diff --git a/pkg/plugin/acceptance_testing.go b/pkg/plugin/acceptance_testing.go index 943dd483b..562f241ae 100644 --- a/pkg/plugin/acceptance_testing.go +++ b/pkg/plugin/acceptance_testing.go @@ -90,6 +90,7 @@ type testDispenserFunc func(*testing.T) (Dispenser, *mock.SpecifierPlugin, *mock // --------------- func testSpecifier_Specify_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) dispenser, mockSpecifier, _, _ := tdf(t) want := Specification{ @@ -127,14 +128,10 @@ func testSpecifier_Specify_Success(t *testing.T, tdf testDispenserFunc) { }, nil) specifier, err := dispenser.DispenseSpecifier() - if err != nil { - t.Fatalf("error dispensing specifier: %+v", err) - } + is.NoErr(err) got, err := specifier.Specify() - if err != nil { - t.Fatalf("error dispensing specifier: %+v", err) - } + is.NoErr(err) if diff := cmp.Diff(got, want); diff != "" { t.Errorf("expected specification: %s", diff) @@ -142,6 
+139,7 @@ func testSpecifier_Specify_Success(t *testing.T, tdf testDispenserFunc) { } func testSpecifier_Specify_Fail(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) dispenser, mockSpecifier, _, _ := tdf(t) want := cerrors.New("specify error") @@ -150,14 +148,10 @@ func testSpecifier_Specify_Fail(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.SpecifierSpecifyResponse{}, want) specifier, err := dispenser.DispenseSpecifier() - if err != nil { - t.Fatalf("error dispensing specifier: %+v", err) - } + is.NoErr(err) _, got := specifier.Specify() - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) } // ------------ @@ -165,6 +159,7 @@ func testSpecifier_Specify_Fail(t *testing.T, tdf testDispenserFunc) { // ------------ func testSource_Configure_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -178,17 +173,14 @@ func testSource_Configure_Success(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.SourceConfigureResponse{}, want) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) got := source.Configure(ctx, cfg) - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) } func testSource_Configure_Fail(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -197,17 +189,14 @@ func testSource_Configure_Fail(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.SourceConfigureResponse{}, nil) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) got := source.Configure(ctx, map[string]string{}) - if got != nil { - t.Fatalf("want: nil, got: %+v", got) - } + is.Equal(got, nil) } func testSource_Start_WithPosition(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -227,14 +216,10 @@ func testSource_Start_WithPosition(t *testing.T, tdf testDispenserFunc) { }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Start(ctx, pos) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -244,6 +229,7 @@ func testSource_Start_WithPosition(t *testing.T, tdf testDispenserFunc) { } func testSource_Start_EmptyPosition(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -261,14 +247,10 @@ func testSource_Start_EmptyPosition(t *testing.T, tdf testDispenserFunc) { }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Start(ctx, nil) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -278,6 +260,7 @@ func testSource_Start_EmptyPosition(t *testing.T, tdf testDispenserFunc) { } func testSource_Read_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -323,21 +306,15 @@ func testSource_Read_Success(t *testing.T, tdf testDispenserFunc) { }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + 
is.NoErr(err) err = source.Start(ctx, nil) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) var got []record.Record for i := 0; i < len(want); i++ { rec, err := source.Read(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) // read at is recorded when we receive the record, adjust in the expectation want[i].ReadAt = rec.ReadAt got = append(got, rec) @@ -349,21 +326,19 @@ func testSource_Read_Success(t *testing.T, tdf testDispenserFunc) { } func testSource_Read_WithoutStart(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, _ := tdf(t) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) _, err = source.Read(ctx) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) } func testSource_Read_AfterStop(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -381,28 +356,23 @@ func testSource_Read_AfterStop(t *testing.T, tdf testDispenserFunc) { Stop(gomock.Any(), cpluginv1.SourceStopRequest{}). DoAndReturn(func(context.Context, cpluginv1.SourceStopRequest) (cpluginv1.SourceStopResponse, error) { close(stopRunCh) - return cpluginv1.SourceStopResponse{}, nil + return cpluginv1.SourceStopResponse{ + LastPosition: []byte("foo"), + }, nil }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Start(ctx, nil) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) - err = source.Stop(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + gotLastPosition, err := source.Stop(ctx) + is.NoErr(err) + is.Equal(gotLastPosition, record.Position("foo")) _, err = source.Read(ctx) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) select { case <-stopRunCh: @@ -428,15 +398,11 @@ func testSource_Read_CancelContext(t *testing.T, tdf testDispenserFunc) { }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) startCtx, startCancel := context.WithCancel(ctx) err = source.Start(startCtx, nil) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) // calling read when source didn't produce records should block until start // ctx is cancelled @@ -448,17 +414,14 @@ func testSource_Read_CancelContext(t *testing.T, tdf testDispenserFunc) { is.True(err != nil) // TODO see if we can change this error into context.Canceled, right now we // follow the default gRPC behavior - if cerrors.Is(err, context.Canceled) { - t.Fatalf("unexpected error: %+v", err) - } - if cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(!cerrors.Is(err, context.Canceled)) + is.True(!cerrors.Is(err, ErrStreamNotOpen)) close(stopRunCh) // stop run channel } func testSource_Ack_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -475,9 +438,7 @@ func testSource_Ack_Success(t *testing.T, tdf testDispenserFunc) { DoAndReturn(func(_ context.Context, stream cpluginv1.SourceRunStream) error { defer close(closeCh) got, err := stream.Recv() - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + 
is.NoErr(err) if diff := cmp.Diff(got.AckPosition, want); diff != "" { t.Errorf("expected ack: %s", diff) } @@ -485,19 +446,13 @@ func testSource_Ack_Success(t *testing.T, tdf testDispenserFunc) { }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Start(ctx, nil) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) err = source.Ack(ctx, record.Position("test-position")) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -510,27 +465,23 @@ func testSource_Ack_Success(t *testing.T, tdf testDispenserFunc) { // acking after the stream is closed should result in an error err = source.Ack(ctx, record.Position("test-position")) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) } func testSource_Ack_WithoutStart(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, _ := tdf(t) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Ack(ctx, []byte("test-position")) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) } func testSource_Run_Fail(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -551,19 +502,13 @@ func testSource_Run_Fail(t *testing.T, tdf testDispenserFunc) { }) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Start(ctx, nil) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) err = source.Ack(ctx, record.Position("test-position")) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -580,9 +525,7 @@ func testSource_Run_Fail(t *testing.T, tdf testDispenserFunc) { unwrapped = cerrors.Unwrap(unwrapped) } - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) // Error is returned through the Ack function, that's the outgoing stream. 
err = source.Ack(ctx, record.Position("test-position")) @@ -593,12 +536,11 @@ func testSource_Run_Fail(t *testing.T, tdf testDispenserFunc) { unwrapped = cerrors.Unwrap(unwrapped) } - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) } func testSource_Teardown_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, mockSource, _ := tdf(t) @@ -620,19 +562,13 @@ func testSource_Teardown_Success(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.SourceTeardownResponse{}, want) source, err := dispenser.DispenseSource() - if err != nil { - t.Fatalf("error dispensing source: %+v", err) - } + is.NoErr(err) err = source.Start(ctx, nil) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + is.NoErr(err) got := source.Teardown(ctx) - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) close(stopRunCh) select { @@ -647,6 +583,7 @@ func testSource_Teardown_Success(t *testing.T, tdf testDispenserFunc) { // ----------------- func testDestination_Configure_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -660,17 +597,14 @@ func testDestination_Configure_Success(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.DestinationConfigureResponse{}, want) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) got := destination.Configure(ctx, cfg) - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) } func testDestination_Configure_Fail(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -679,17 +613,14 @@ func testDestination_Configure_Fail(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.DestinationConfigureResponse{}, nil) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Configure(ctx, map[string]string{}) - if err != nil { - t.Fatalf("want: nil, got: %+v", err) - } + is.NoErr(err) } func testDestination_Start_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -707,14 +638,10 @@ func testDestination_Start_Success(t *testing.T, tdf testDispenserFunc) { }) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Start(ctx) - if err != nil { - t.Fatalf("want: nil, got: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -724,6 +651,7 @@ func testDestination_Start_Success(t *testing.T, tdf testDispenserFunc) { } func testDestination_Start_Fail(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -734,17 +662,14 @@ func testDestination_Start_Fail(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.DestinationStartResponse{}, want) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) got := destination.Start(ctx) - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + 
is.Equal(got.Error(), want.Error()) } func testDestination_Write_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -767,9 +692,7 @@ func testDestination_Write_Success(t *testing.T, tdf testDispenserFunc) { DoAndReturn(func(_ context.Context, stream cpluginv1.DestinationRunStream) error { defer close(closeCh) got, err := stream.Recv() - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) if diff := cmp.Diff(got.Record, want); diff != "" { t.Errorf("expected ack: %s", diff) } @@ -777,14 +700,10 @@ func testDestination_Write_Success(t *testing.T, tdf testDispenserFunc) { }) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Start(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) err = destination.Write(ctx, record.Record{ Position: want.Position, @@ -793,9 +712,7 @@ func testDestination_Write_Success(t *testing.T, tdf testDispenserFunc) { Key: record.RawData{Raw: want.Key.(cpluginv1.RawData)}, Payload: record.StructuredData{"baz": "qux"}, }) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -807,27 +724,23 @@ func testDestination_Write_Success(t *testing.T, tdf testDispenserFunc) { time.Sleep(time.Millisecond * 50) err = destination.Write(ctx, record.Record{}) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) } func testDestination_Write_WithoutStart(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, _ := tdf(t) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Write(ctx, record.Record{}) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) } func testDestination_Ack_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -846,29 +759,21 @@ func testDestination_Ack_Success(t *testing.T, tdf testDispenserFunc) { err := stream.Send(cpluginv1.DestinationRunResponse{ AckPosition: p, }) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) } return nil }) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Start(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) var got []record.Position for i := 0; i < len(want); i++ { pos, err := destination.Ack(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) got = append(got, pos) } @@ -878,6 +783,7 @@ func testDestination_Ack_Success(t *testing.T, tdf testDispenserFunc) { } func testDestination_Ack_WithError(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -894,44 +800,33 @@ func testDestination_Ack_WithError(t *testing.T, tdf testDispenserFunc) { AckPosition: wantPos, Error: wantErr.Error(), }) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) return nil }) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error 
dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Start(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) gotPos, gotErr := destination.Ack(ctx) if diff := cmp.Diff(gotPos, wantPos); diff != "" { t.Errorf("expected position: %s", diff) } - if gotErr.Error() != wantErr.Error() { - t.Fatalf("want: %+v, got: %+v", wantErr, gotErr) - } + is.Equal(gotErr.Error(), wantErr.Error()) } func testDestination_Ack_WithoutStart(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, _ := tdf(t) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) _, err = destination.Ack(ctx) - if !cerrors.Is(err, ErrStreamNotOpen) { - t.Fatalf("unexpected error: %+v", err) - } + is.True(cerrors.Is(err, ErrStreamNotOpen)) } func testDestination_Run_Fail(t *testing.T, tdf testDispenserFunc) { @@ -956,19 +851,13 @@ func testDestination_Run_Fail(t *testing.T, tdf testDispenserFunc) { }) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Start(ctx) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) err = destination.Write(ctx, record.Record{}) - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } + is.NoErr(err) select { case <-closeCh: @@ -992,6 +881,7 @@ func testDestination_Run_Fail(t *testing.T, tdf testDispenserFunc) { } func testDestination_Teardown_Success(t *testing.T, tdf testDispenserFunc) { + is := is.New(t) ctx := context.Background() dispenser, _, _, mockDestination := tdf(t) @@ -1016,23 +906,15 @@ func testDestination_Teardown_Success(t *testing.T, tdf testDispenserFunc) { Return(cpluginv1.DestinationTeardownResponse{}, want) destination, err := dispenser.DispenseDestination() - if err != nil { - t.Fatalf("error dispensing destination: %+v", err) - } + is.NoErr(err) err = destination.Start(ctx) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - err = destination.Stop(ctx) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + is.NoErr(err) + err = destination.Stop(ctx, nil) + is.NoErr(err) got := destination.Teardown(ctx) - if got.Error() != want.Error() { - t.Fatalf("want: %+v, got: %+v", want, got) - } + is.Equal(got.Error(), want.Error()) close(stopRunCh) select { @@ -1053,7 +935,9 @@ func testDestination_Stop_CloseSend(t *testing.T, tdf testDispenserFunc) { Start(gomock.Any(), cpluginv1.DestinationStartRequest{}). Return(cpluginv1.DestinationStartResponse{}, nil) mockDestination.EXPECT(). - Stop(gomock.Any(), cpluginv1.DestinationStopRequest{}). + Stop(gomock.Any(), cpluginv1.DestinationStopRequest{ + LastPosition: []byte("foo"), + }). Return(cpluginv1.DestinationStopResponse{}, nil) mockDestination.EXPECT(). Run(gomock.Any(), gomock.Any()). 
@@ -1074,18 +958,12 @@ func testDestination_Stop_CloseSend(t *testing.T, tdf testDispenserFunc) {
 		Return(cpluginv1.DestinationTeardownResponse{}, nil)
 
 	destination, err := dispenser.DispenseDestination()
-	if err != nil {
-		t.Fatalf("error dispensing destination: %+v", err)
-	}
+	is.NoErr(err)
 
 	err = destination.Start(ctx)
-	if err != nil {
-		t.Fatalf("unexpected error: %s", err)
-	}
-	err = destination.Stop(ctx)
-	if err != nil {
-		t.Fatalf("unexpected error: %s", err)
-	}
+	is.NoErr(err)
+	err = destination.Stop(ctx, record.Position("foo"))
+	is.NoErr(err)
 
 	select {
 	case <-closeCh:
diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go
index ad9a3f7e6..5cf40a2c1 100644
--- a/pkg/plugin/plugin.go
+++ b/pkg/plugin/plugin.go
@@ -53,11 +53,14 @@ type SourcePlugin interface {
 
 	// Stop should be called to invoke a graceful shutdown of the stream. It
 	// will signal the plugin to stop retrieving new records and flush any
-	// records that might be cached. The stream will still remain open so
-	// Conduit can fetch the remaining records and send back any outstanding
-	// acks. After the stream is closed the Read method will return the
-	// appropriate error signaling the stream is closed.
-	Stop(context.Context) error
+	// records that might be cached. The response will contain the position of
+	// the last record in the stream. Conduit should keep reading records until
+	// it encounters the record with the last position. After it has received
+	// all records and sent back acks for all successfully processed records,
+	// it should call Teardown to close the stream. After the stream is closed
+	// the Read method will return the appropriate error signaling the stream
+	// is closed.
+	Stop(context.Context) (record.Position, error)
 
 	// Teardown is the last call that must be issued before discarding the
 	// plugin. It signals to the plugin it can release any open resources and
@@ -91,13 +94,20 @@ type DestinationPlugin interface {
 	// successfully processed the function returns the position and an error.
 	Ack(context.Context) (record.Position, error)
 
+	// Stop signals to the plugin that the record with the specified position is
+	// the last one and no more records will be written to the stream after it.
+	// Once the plugin receives the last record it should flush any records that
+	// might be cached and not yet written to the 3rd party resource.
+	//
 	// Stop should be called to invoke a graceful shutdown of the stream. It
-	// will signal the plugin that no more records will be written to the stream
-	// and that it should flush any records that might be cached. The stream
-	// will still remain open so Conduit can fetch the remaining acks. After the
+	// will signal the plugin that after receiving the record with the last
+	// position no more records will be written to the stream and that the
+	// plugin should flush any records that might be cached. The stream will
+	// still remain open so Conduit can fetch the remaining acks. After all acks
+	// are received, Conduit should call Teardown to close the stream. After the
 	// stream is closed the Ack method will return the appropriate error
 	// signaling the stream is closed.
-	Stop(context.Context) error
+	Stop(context.Context, record.Position) error
 
 	// Teardown is the last call that must be issued before discarding the
 	// plugin.
It signals to the plugin it can release any open resources and From 3bdec6c66ab2687dc7a830d8170bd1894e0962df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 13:22:52 +0200 Subject: [PATCH 20/46] update standalone plugin implementation --- pkg/plugin/standalone/v1/destination.go | 4 ++-- .../standalone/v1/internal/fromproto/source.go | 4 ++++ .../standalone/v1/internal/toproto/destination.go | 6 ++++-- pkg/plugin/standalone/v1/source.go | 13 ++++++++----- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/pkg/plugin/standalone/v1/destination.go b/pkg/plugin/standalone/v1/destination.go index 3e515c6b6..09f4e2eed 100644 --- a/pkg/plugin/standalone/v1/destination.go +++ b/pkg/plugin/standalone/v1/destination.go @@ -123,7 +123,7 @@ func (s *destinationPluginClient) Ack(ctx context.Context) (record.Position, err return position, nil } -func (s *destinationPluginClient) Stop(ctx context.Context) error { +func (s *destinationPluginClient) Stop(ctx context.Context, lastPosition record.Position) error { var errOut error if s.stream == nil { return plugin.ErrStreamNotOpen @@ -134,7 +134,7 @@ func (s *destinationPluginClient) Stop(ctx context.Context) error { errOut = multierror.Append(errOut, unwrapGRPCError(err)) } - protoReq := toproto.DestinationStopRequest() + protoReq := toproto.DestinationStopRequest(lastPosition) protoResp, err := s.grpcClient.Stop(ctx, protoReq) if err != nil { errOut = multierror.Append(errOut, unwrapGRPCError(err)) diff --git a/pkg/plugin/standalone/v1/internal/fromproto/source.go b/pkg/plugin/standalone/v1/internal/fromproto/source.go index 0d60f42b4..858ab9c77 100644 --- a/pkg/plugin/standalone/v1/internal/fromproto/source.go +++ b/pkg/plugin/standalone/v1/internal/fromproto/source.go @@ -26,3 +26,7 @@ func SourceRunResponse(in *connectorv1.Source_Run_Response) (record.Record, erro } return out, nil } + +func SourceStopResponse(in *connectorv1.Source_Stop_Response) (record.Position, error) { + return in.LastPosition, nil +} diff --git a/pkg/plugin/standalone/v1/internal/toproto/destination.go b/pkg/plugin/standalone/v1/internal/toproto/destination.go index 0bee813aa..5e93e51e9 100644 --- a/pkg/plugin/standalone/v1/internal/toproto/destination.go +++ b/pkg/plugin/standalone/v1/internal/toproto/destination.go @@ -41,8 +41,10 @@ func DestinationRunRequest(in record.Record) (*connectorv1.Destination_Run_Reque return &out, nil } -func DestinationStopRequest() *connectorv1.Destination_Stop_Request { - return &connectorv1.Destination_Stop_Request{} +func DestinationStopRequest(in record.Position) *connectorv1.Destination_Stop_Request { + return &connectorv1.Destination_Stop_Request{ + LastPosition: in, + } } func DestinationTeardownRequest() *connectorv1.Destination_Teardown_Request { diff --git a/pkg/plugin/standalone/v1/source.go b/pkg/plugin/standalone/v1/source.go index f40e405bf..79d0bdbbf 100644 --- a/pkg/plugin/standalone/v1/source.go +++ b/pkg/plugin/standalone/v1/source.go @@ -140,18 +140,21 @@ func (s *sourcePluginClient) ackErrorCause(err error) error { return recvErr } -func (s *sourcePluginClient) Stop(ctx context.Context) error { +func (s *sourcePluginClient) Stop(ctx context.Context) (record.Position, error) { if s.stream == nil { - return plugin.ErrStreamNotOpen + return nil, plugin.ErrStreamNotOpen } protoReq := toproto.SourceStopRequest() protoResp, err := s.grpcClient.Stop(ctx, protoReq) if err != nil { - return unwrapGRPCError(err) + return nil, unwrapGRPCError(err) } - _ = protoResp // response is empty - 
return nil + goResp, err := fromproto.SourceStopResponse(protoResp) + if err != nil { + return nil, err + } + return goResp, nil } func (s *sourcePluginClient) Teardown(ctx context.Context) error { From 6ab51fc78c9977f381f7d35318532699a4f115aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 13:27:49 +0200 Subject: [PATCH 21/46] update builtin plugin implementation --- pkg/plugin/builtin/v1/destination.go | 6 +++--- pkg/plugin/builtin/v1/internal/fromplugin/source.go | 4 ++++ .../builtin/v1/internal/toplugin/destination.go | 6 ++++-- pkg/plugin/builtin/v1/source.go | 13 ++++++++----- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/pkg/plugin/builtin/v1/destination.go b/pkg/plugin/builtin/v1/destination.go index ac8a6bfda..055e256f0 100644 --- a/pkg/plugin/builtin/v1/destination.go +++ b/pkg/plugin/builtin/v1/destination.go @@ -140,15 +140,15 @@ func (s *destinationPluginAdapter) Ack(ctx context.Context) (record.Position, er return position, nil } -func (s *destinationPluginAdapter) Stop(ctx context.Context) error { +func (s *destinationPluginAdapter) Stop(ctx context.Context, lastPosition record.Position) error { if s.stream == nil { return plugin.ErrStreamNotOpen } s.stream.stopSend() - s.logger.Trace(ctx).Msg("calling Stop") - resp, err := runSandbox(s.impl.Stop, s.withLogger(ctx), toplugin.DestinationStopRequest()) + s.logger.Trace(ctx).Bytes(log.RecordPositionField, lastPosition).Msg("calling Stop") + resp, err := runSandbox(s.impl.Stop, s.withLogger(ctx), toplugin.DestinationStopRequest(lastPosition)) if err != nil { return err } diff --git a/pkg/plugin/builtin/v1/internal/fromplugin/source.go b/pkg/plugin/builtin/v1/internal/fromplugin/source.go index 8fbd97c80..e8daeadf6 100644 --- a/pkg/plugin/builtin/v1/internal/fromplugin/source.go +++ b/pkg/plugin/builtin/v1/internal/fromplugin/source.go @@ -26,3 +26,7 @@ func SourceRunResponse(in cpluginv1.SourceRunResponse) (record.Record, error) { } return out, nil } + +func SourceStopResponse(in cpluginv1.SourceStopResponse) (record.Position, error) { + return in.LastPosition, nil +} diff --git a/pkg/plugin/builtin/v1/internal/toplugin/destination.go b/pkg/plugin/builtin/v1/internal/toplugin/destination.go index 0acda9bb1..3df9e13dc 100644 --- a/pkg/plugin/builtin/v1/internal/toplugin/destination.go +++ b/pkg/plugin/builtin/v1/internal/toplugin/destination.go @@ -44,8 +44,10 @@ func DestinationRunRequest(in record.Record) (cpluginv1.DestinationRunRequest, e return out, nil } -func DestinationStopRequest() cpluginv1.DestinationStopRequest { - return cpluginv1.DestinationStopRequest{} +func DestinationStopRequest(in record.Position) cpluginv1.DestinationStopRequest { + return cpluginv1.DestinationStopRequest{ + LastPosition: in, + } } func DestinationTeardownRequest() cpluginv1.DestinationTeardownRequest { diff --git a/pkg/plugin/builtin/v1/source.go b/pkg/plugin/builtin/v1/source.go index bec5b1f42..b3e8caa38 100644 --- a/pkg/plugin/builtin/v1/source.go +++ b/pkg/plugin/builtin/v1/source.go @@ -145,19 +145,22 @@ func (s *sourcePluginAdapter) Ack(ctx context.Context, p record.Position) error return nil } -func (s *sourcePluginAdapter) Stop(ctx context.Context) error { +func (s *sourcePluginAdapter) Stop(ctx context.Context) (record.Position, error) { if s.stream == nil { - return plugin.ErrStreamNotOpen + return nil, plugin.ErrStreamNotOpen } s.logger.Trace(ctx).Msg("calling Stop") resp, err := runSandbox(s.impl.Stop, s.withLogger(ctx), toplugin.SourceStopRequest()) if err != nil { - return 
err + return nil, err + } + out, err := fromplugin.SourceStopResponse(resp) + if err != nil { + return nil, err } - _ = resp // empty response - return nil + return out, nil } func (s *sourcePluginAdapter) Teardown(ctx context.Context) error { From 08115feb72de0a256d291560fcd5d210b5567bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 15:22:50 +0200 Subject: [PATCH 22/46] update connector --- pkg/connector/connector.go | 12 ++++++++---- pkg/connector/destination.go | 31 +++++++++++++++++++++++-------- pkg/connector/mock/connector.go | 22 +++++++++++++++++++--- pkg/connector/source.go | 21 +++++++++++---------- 4 files changed, 61 insertions(+), 25 deletions(-) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 3855a9c95..f875e205e 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -87,10 +87,10 @@ type Source interface { // processed and can be acknowledged. Ack(context.Context, record.Position) error - // Stop signals to the source to stop producing records. Note that after - // this call Read can still produce records that have been cached by the - // connector. - Stop(context.Context) error + // Stop signals to the source to stop producing records. After this call + // Read will produce records until the record with the last position has + // been read (Conduit might have already received that record). + Stop(context.Context) (record.Position, error) } // Destination is a connector that can write records to a destination. @@ -111,6 +111,10 @@ type Destination interface { // processed and returns the position of that record. If the record wasn't // successfully processed the function returns the position and an error. Ack(context.Context) (record.Position, error) + + // Stop signals to the destination that no more records will be produced + // after record with the last position. + Stop(context.Context) (record.Position, error) } // Config collects common data stored for a connector. diff --git a/pkg/connector/destination.go b/pkg/connector/destination.go index e2ad6de2f..22bd36db7 100644 --- a/pkg/connector/destination.go +++ b/pkg/connector/destination.go @@ -21,7 +21,6 @@ import ( "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/conduitio/conduit/pkg/foundation/log" - "github.com/conduitio/conduit/pkg/foundation/multierror" "github.com/conduitio/conduit/pkg/plugin" "github.com/conduitio/conduit/pkg/record" ) @@ -156,6 +155,25 @@ func (s *destination) Open(ctx context.Context) error { return nil } +func (s *destination) Stop(ctx context.Context, lastPosition record.Position) error { + cleanup, err := s.preparePluginCall() + defer cleanup() + if err != nil { + return err + } + + s.logger.Debug(ctx). + Bytes(log.RecordPositionField, lastPosition). + Msg("sending stop signal to destination connector plugin") + err = s.plugin.Stop(ctx, lastPosition) + if err != nil { + return cerrors.Errorf("could not stop destination plugin: %w", err) + } + + s.logger.Debug(ctx).Msg("destination connector plugin successfully responded to stop signal") + return nil +} + func (s *destination) Teardown(ctx context.Context) error { // lock destination as we are about to mutate the plugin field s.m.Lock() @@ -164,23 +182,20 @@ func (s *destination) Teardown(ctx context.Context) error { return plugin.ErrPluginNotRunning } - s.logger.Debug(ctx).Msg("stopping destination connector plugin") - err := s.plugin.Stop(ctx) - - // wait for any calls to the plugin to stop running first (e.g. 
Ack or Write) + // wait for any calls to the plugin to stop running first (e.g. Stop, Ack or Write) s.wg.Wait() s.logger.Debug(ctx).Msg("tearing down destination connector plugin") - err = multierror.Append(err, s.plugin.Teardown(ctx)) + err := s.plugin.Teardown(ctx) s.plugin = nil s.persister.ConnectorStopped() if err != nil { - return cerrors.Errorf("could not tear down plugin: %w", err) + return cerrors.Errorf("could not tear down destination connector plugin: %w", err) } - s.logger.Info(ctx).Msg("connector plugin successfully torn down") + s.logger.Info(ctx).Msg("destination connector plugin successfully torn down") return nil } diff --git a/pkg/connector/mock/connector.go b/pkg/connector/mock/connector.go index 100b9efda..6f7520ab8 100644 --- a/pkg/connector/mock/connector.go +++ b/pkg/connector/mock/connector.go @@ -201,11 +201,12 @@ func (mr *SourceMockRecorder) State() *gomock.Call { } // Stop mocks base method. -func (m *Source) Stop(arg0 context.Context) error { +func (m *Source) Stop(arg0 context.Context) (record.Position, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Stop", arg0) - ret0, _ := ret[0].(error) - return ret0 + ret0, _ := ret[0].(record.Position) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Stop indicates an expected call of Stop. @@ -442,6 +443,21 @@ func (mr *DestinationMockRecorder) State() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*Destination)(nil).State)) } +// Stop mocks base method. +func (m *Destination) Stop(arg0 context.Context) (record.Position, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop", arg0) + ret0, _ := ret[0].(record.Position) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Stop indicates an expected call of Stop. +func (mr *DestinationMockRecorder) Stop(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*Destination)(nil).Stop), arg0) +} + // Teardown mocks base method. func (m *Destination) Teardown(arg0 context.Context) error { m.ctrl.T.Helper() diff --git a/pkg/connector/source.go b/pkg/connector/source.go index 3843947d5..36fe9a978 100644 --- a/pkg/connector/source.go +++ b/pkg/connector/source.go @@ -156,21 +156,23 @@ func (s *source) Open(ctx context.Context) error { return nil } -func (s *source) Stop(ctx context.Context) error { +func (s *source) Stop(ctx context.Context) (record.Position, error) { cleanup, err := s.preparePluginCall() defer cleanup() if err != nil { - return err + return nil, err } - s.logger.Debug(ctx).Msg("stopping source connector plugin") - err = s.plugin.Stop(ctx) + s.logger.Debug(ctx).Msg("sending stop signal to source connector plugin") + lastPosition, err := s.plugin.Stop(ctx) if err != nil { - return cerrors.Errorf("could not stop plugin: %w", err) + return nil, cerrors.Errorf("could not stop source plugin: %w", err) } - s.logger.Info(ctx).Msg("connector plugin successfully stopped") - return nil + s.logger.Info(ctx). + Bytes(log.RecordPositionField, lastPosition). 
+ Msg("source connector plugin successfully responded to stop signal") + return lastPosition, nil } func (s *source) Teardown(ctx context.Context) error { @@ -185,17 +187,16 @@ func (s *source) Teardown(ctx context.Context) error { s.wg.Wait() s.logger.Debug(ctx).Msg("tearing down source connector plugin") - err := s.plugin.Teardown(ctx) s.plugin = nil s.persister.ConnectorStopped() if err != nil { - return cerrors.Errorf("could not tear down plugin: %w", err) + return cerrors.Errorf("could not tear down source connector plugin: %w", err) } - s.logger.Info(ctx).Msg("connector plugin successfully torn down") + s.logger.Info(ctx).Msg("source connector plugin successfully torn down") return nil } From 65e087ad5f1aa94393bd6d7b4d51f0b3d68d59e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 15:42:28 +0200 Subject: [PATCH 23/46] update nodes --- pkg/connector/connector.go | 2 +- pkg/connector/destination.go | 2 ++ pkg/connector/mock/connector.go | 13 ++++++------- pkg/connector/source.go | 2 +- pkg/pipeline/lifecycle.go | 10 +++++++--- pkg/pipeline/stream/node.go | 10 +++++----- pkg/pipeline/stream/source.go | 7 ++++--- pkg/pipeline/stream/stream_test.go | 14 +++++++------- 8 files changed, 33 insertions(+), 27 deletions(-) diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index f875e205e..891f3c66b 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -114,7 +114,7 @@ type Destination interface { // Stop signals to the destination that no more records will be produced // after record with the last position. - Stop(context.Context) (record.Position, error) + Stop(context.Context, record.Position) error } // Config collects common data stored for a connector. diff --git a/pkg/connector/destination.go b/pkg/connector/destination.go index 22bd36db7..695805ed2 100644 --- a/pkg/connector/destination.go +++ b/pkg/connector/destination.go @@ -59,6 +59,8 @@ type destination struct { wg sync.WaitGroup } +var _ Destination = (*destination)(nil) + func (s *destination) ID() string { return s.XID } diff --git a/pkg/connector/mock/connector.go b/pkg/connector/mock/connector.go index 6f7520ab8..ebd732a23 100644 --- a/pkg/connector/mock/connector.go +++ b/pkg/connector/mock/connector.go @@ -444,18 +444,17 @@ func (mr *DestinationMockRecorder) State() *gomock.Call { } // Stop mocks base method. -func (m *Destination) Stop(arg0 context.Context) (record.Position, error) { +func (m *Destination) Stop(arg0 context.Context, arg1 record.Position) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stop", arg0) - ret0, _ := ret[0].(record.Position) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Stop", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 } // Stop indicates an expected call of Stop. -func (mr *DestinationMockRecorder) Stop(arg0 interface{}) *gomock.Call { +func (mr *DestinationMockRecorder) Stop(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*Destination)(nil).Stop), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*Destination)(nil).Stop), arg0, arg1) } // Teardown mocks base method. 
diff --git a/pkg/connector/source.go b/pkg/connector/source.go
index 36fe9a978..fe7570a68 100644
--- a/pkg/connector/source.go
+++ b/pkg/connector/source.go
@@ -58,7 +58,7 @@ type source struct {
 	wg sync.WaitGroup
 }
 
-// not running -> running -> stopping -> not running
+var _ Source = (*source)(nil)
 
 func (s *source) ID() string {
 	return s.XID
 }
diff --git a/pkg/pipeline/lifecycle.go b/pkg/pipeline/lifecycle.go
index 3c2e6f17f..cac735f29 100644
--- a/pkg/pipeline/lifecycle.go
+++ b/pkg/pipeline/lifecycle.go
@@ -86,15 +86,19 @@ func (s *Service) stopWithReason(ctx context.Context, pl *Instance, reason error
 	}
 
 	s.logger.Debug(ctx).Str(log.PipelineIDField, pl.ID).Msg("stopping pipeline")
+	var err error
 	for _, n := range pl.n {
 		if node, ok := n.(stream.StoppableNode); ok {
 			// stop all pub nodes
 			s.logger.Trace(ctx).Str(log.NodeIDField, n.ID()).Msg("stopping node")
-			node.Stop(reason)
+			stopErr := node.Stop(ctx, reason)
+			if stopErr != nil {
+				s.logger.Err(ctx, stopErr).Str(log.NodeIDField, n.ID()).Msg("stop failed")
+				err = multierror.Append(err, stopErr)
+			}
 		}
 	}
-
-	return nil
+	return err
 }
 
 // StopAll will ask all the pipelines to stop gracefully
diff --git a/pkg/pipeline/stream/node.go b/pkg/pipeline/stream/node.go
index 36f725a80..6cb8d1d48 100644
--- a/pkg/pipeline/stream/node.go
+++ b/pkg/pipeline/stream/node.go
@@ -98,11 +98,11 @@ type StoppableNode interface {
 
 	// Stop signals a running StopNode that it should gracefully shutdown. It
 	// should stop producing new messages, wait to receive acks/nacks for any
-	// in-flight messages, close the outgoing channel and return nil from
-	// Node.Run. Stop should return right away, not waiting for the node to
-	// actually stop running. If the node is not running the function does not
-	// do anything. The reason supplied to Stop will be returned by Node.Run.
-	Stop(reason error)
+	// in-flight messages, close the outgoing channel and return from Node.Run.
+	// Stop should return right away, not waiting for the node to actually stop
+	// running. If the node is not running the function does not do anything.
+	// The reason supplied to Stop will be returned by Node.Run.
+	Stop(ctx context.Context, reason error) error
 }
 
 // LoggingNode is a node which expects a logger.
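Since StoppableNode.Stop now takes a context and returns an error while still being required to return immediately, implementations need a close-once signal that Run can observe. A minimal sketch of a node satisfying the new contract follows; the noopNode type is hypothetical and only illustrates the pattern, it produces nothing and merely waits to be stopped.

package sketch

import (
	"context"
	"sync"
)

// noopNode illustrates the Stop(ctx, reason) contract: signal exactly once,
// return right away, and let Run return the supplied reason.
type noopNode struct {
	once     sync.Once
	stopChan chan struct{}
	reason   error
}

func newNoopNode() *noopNode {
	return &noopNode{stopChan: make(chan struct{})}
}

func (n *noopNode) Run(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-n.stopChan:
		return n.reason // Run returns the reason supplied to Stop
	}
}

// Stop does not wait for Run to actually stop; calling it again (or on a
// node that is not running) is a safe no-op.
func (n *noopNode) Stop(ctx context.Context, reason error) error {
	n.once.Do(func() {
		n.reason = reason
		close(n.stopChan)
	})
	return nil
}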
diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go index cb5efbdb2..a182a0dad 100644 --- a/pkg/pipeline/stream/source.go +++ b/pkg/pipeline/stream/source.go @@ -121,11 +121,12 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { } } -func (n *SourceNode) Stop(reason error) { - ctx := context.TODO() // TODO get context as parameter +func (n *SourceNode) Stop(ctx context.Context, reason error) error { n.logger.Err(ctx, reason).Msg("stopping source connector") n.stopReason = reason - _ = n.Source.Stop(ctx) // TODO return error + lastPosition, err := n.Source.Stop(ctx) + _ = lastPosition // TODO store lastPosition + return err } func (n *SourceNode) Pub() <-chan *Message { diff --git a/pkg/pipeline/stream/stream_test.go b/pkg/pipeline/stream/stream_test.go index 1acc9c814..70426ec79 100644 --- a/pkg/pipeline/stream/stream_test.go +++ b/pkg/pipeline/stream/stream_test.go @@ -79,7 +79,7 @@ func Example_simpleStream() { go runNode(ctx, &wg, node1) // stop node after 150ms, which should be enough to process the 10 messages - time.AfterFunc(150*time.Millisecond, func() { node1.Stop(nil) }) + time.AfterFunc(150*time.Millisecond, func() { _ = node1.Stop(ctx, nil) }) // give the node some time to process the records, plus a bit of time to stop if waitTimeout(&wg, 1000*time.Millisecond) { killAll() @@ -197,8 +197,8 @@ func Example_complexStream() { time.AfterFunc( 250*time.Millisecond, func() { - node1.Stop(nil) - node3.Stop(nil) + _ = node1.Stop(ctx, nil) + _ = node3.Stop(ctx, nil) }, ) // give the nodes some time to process the records, plus a bit of time to stop @@ -310,13 +310,13 @@ func generatorSource(ctrl *gomock.Controller, logger log.CtxLogger, nodeID strin source.EXPECT().Read(gomock.Any()).DoAndReturn(func(ctx context.Context) (record.Record, error) { time.Sleep(delay) - position++ - if position > recordCount { + if position == recordCount { // block until Stop is called <-stop return record.Record{}, plugin.ErrStreamNotOpen } + position++ return record.Record{ // SourceID would normally be the source node ID, but since we need // to add the node ID to the position to create unique positions we @@ -325,9 +325,9 @@ func generatorSource(ctrl *gomock.Controller, logger log.CtxLogger, nodeID strin Position: record.Position(nodeID + "-" + strconv.Itoa(position)), }, nil }).MinTimes(recordCount + 1) - source.EXPECT().Stop(gomock.Any()).DoAndReturn(func(context.Context) error { + source.EXPECT().Stop(gomock.Any()).DoAndReturn(func(context.Context) (record.Position, error) { close(stop) - return nil + return record.Position(nodeID + "-" + strconv.Itoa(position)), nil }) source.EXPECT().Errors().Return(make(chan error)) From ffd3637c49983397d5120fa6dfce41af4db94028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 16:57:05 +0200 Subject: [PATCH 24/46] change plugin semantics, close stream on teardown --- pkg/plugin/acceptance_testing.go | 30 +++------- pkg/plugin/builtin/v1/destination.go | 75 +++++++++++++------------ pkg/plugin/builtin/v1/source.go | 26 +++++++-- pkg/plugin/standalone/v1/destination.go | 22 ++++---- pkg/plugin/standalone/v1/source.go | 16 ------ 5 files changed, 78 insertions(+), 91 deletions(-) diff --git a/pkg/plugin/acceptance_testing.go b/pkg/plugin/acceptance_testing.go index 562f241ae..a995dc079 100644 --- a/pkg/plugin/acceptance_testing.go +++ b/pkg/plugin/acceptance_testing.go @@ -74,7 +74,7 @@ func AcceptanceTestV1(t *testing.T, tdf testDispenserFunc) { run(t, tdf, testDestination_Ack_WithoutStart) 
run(t, tdf, testDestination_Run_Fail)
 	run(t, tdf, testDestination_Teardown_Success)
-	run(t, tdf, testDestination_Stop_CloseSend)
+	run(t, tdf, testDestination_Teardown_CloseSend)
 }
 
 func run(t *testing.T, tdf testDispenserFunc, test func(*testing.T, testDispenserFunc)) {
@@ -527,16 +527,9 @@ func testSource_Run_Fail(t *testing.T, tdf testDispenserFunc) {
 
 	is.Equal(got.Error(), want.Error())
 
-	// Error is returned through the Ack function, that's the outgoing stream.
+	// Ack returns just a generic error
 	err = source.Ack(ctx, record.Position("test-position"))
-	// Unwrap inner-most error
-	got = nil
-	for unwrapped := err; unwrapped != nil; {
-		got = unwrapped
-		unwrapped = cerrors.Unwrap(unwrapped)
-	}
-
-	is.Equal(got.Error(), want.Error())
+	is.True(cerrors.Is(err, ErrStreamNotOpen))
 }
 
 func testSource_Teardown_Success(t *testing.T, tdf testDispenserFunc) {
@@ -924,7 +917,7 @@ func testDestination_Teardown_Success(t *testing.T, tdf testDispenserFunc) {
 	}
 }
 
-func testDestination_Stop_CloseSend(t *testing.T, tdf testDispenserFunc) {
+func testDestination_Teardown_CloseSend(t *testing.T, tdf testDispenserFunc) {
 	is := is.New(t)
 
 	ctx := context.Background()
@@ -945,12 +938,6 @@ func testDestination_Stop_CloseSend(t *testing.T, tdf testDispenserFunc) {
 			_, recvErr := stream.Recv()
 			is.Equal(recvErr, io.EOF)
 			close(closeCh)
-
-			// we should still be able to send acks back even if incoming stream
-			// is closed
-			sendErr := stream.Send(cpluginv1.DestinationRunResponse{})
-			is.NoErr(sendErr)
-
 			return recvErr
 		})
 	mockDestination.EXPECT().
@@ -965,6 +952,9 @@ func testDestination_Stop_CloseSend(t *testing.T, tdf testDispenserFunc) {
 	err = destination.Stop(ctx, record.Position("foo"))
 	is.NoErr(err)
 
+	err = destination.Teardown(ctx)
+	is.NoErr(err)
+
 	select {
 	case <-closeCh:
 		// all good, outgoing stream was closed
@@ -972,10 +962,4 @@ func testDestination_Stop_CloseSend(t *testing.T, tdf testDispenserFunc) {
 		is.Fail() // expected outgoing stream to be closed
 	}
-	// fetching an ack should still work
-	_, err = destination.Ack(ctx)
-	is.NoErr(err)
-
-	err = destination.Teardown(ctx)
-	is.NoErr(err)
 }
diff --git a/pkg/plugin/builtin/v1/destination.go b/pkg/plugin/builtin/v1/destination.go
index 055e256f0..c01fd3bc8 100644
--- a/pkg/plugin/builtin/v1/destination.go
+++ b/pkg/plugin/builtin/v1/destination.go
@@ -92,9 +92,11 @@ func (s *destinationPluginAdapter) Start(ctx context.Context) error {
 		s.logger.Trace(ctx).Msg("calling Run")
 		err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.DestinationRunStream(s.stream))
 		if err != nil {
-			s.stream.stopAll(cerrors.Errorf("error in run: %w", err))
+			if !s.stream.stop(cerrors.Errorf("error in run: %w", err)) {
+				s.logger.Err(ctx, err).Msg("stream already stopped")
+			}
 		} else {
-			s.stream.stopAll(plugin.ErrStreamNotOpen)
+			s.stream.stop(plugin.ErrStreamNotOpen)
 		}
 		s.logger.Trace(ctx).Msg("Run stopped")
 	}()
@@ -145,8 +147,6 @@ func (s *destinationPluginAdapter) Stop(ctx context.Context, lastPosition record
 		return plugin.ErrStreamNotOpen
 	}
 
-	s.stream.stopSend()
-
 	s.logger.Trace(ctx).Bytes(log.RecordPositionField, lastPosition).Msg("calling Stop")
 	resp, err := runSandbox(s.impl.Stop, s.withLogger(ctx), toplugin.DestinationStopRequest(lastPosition))
 	if err != nil {
@@ -158,6 +158,11 @@ func (s *destinationPluginAdapter) Stop(ctx context.Context, lastPosition record
 }
 
 func (s *destinationPluginAdapter) Teardown(ctx context.Context) error {
+	if s.stream != nil {
+		// stop stream if it's open
+		_ = s.stream.stop(nil)
+	}
+
 	s.logger.Trace(ctx).Msg("calling
Teardown") resp, err := runSandbox(s.impl.Teardown, s.withLogger(ctx), toplugin.DestinationTeardownRequest()) if err != nil { @@ -169,25 +174,19 @@ func (s *destinationPluginAdapter) Teardown(ctx context.Context) error { } func newDestinationRunStream(ctx context.Context) *destinationRunStream { - sendCtx, sendCancel := context.WithCancel(ctx) - recvCtx, recvCancel := context.WithCancel(ctx) return &destinationRunStream{ - sendCtx: sendCtx, - recvCtx: recvCtx, - closeSend: sendCancel, - closeRecv: recvCancel, - reqChan: make(chan cpluginv1.DestinationRunRequest), - respChan: make(chan cpluginv1.DestinationRunResponse), + ctx: ctx, + stopChan: make(chan struct{}), + reqChan: make(chan cpluginv1.DestinationRunRequest), + respChan: make(chan cpluginv1.DestinationRunResponse), } } type destinationRunStream struct { - sendCtx context.Context - recvCtx context.Context - closeSend context.CancelFunc - closeRecv context.CancelFunc - reqChan chan cpluginv1.DestinationRunRequest - respChan chan cpluginv1.DestinationRunResponse + ctx context.Context + stopChan chan struct{} + reqChan chan cpluginv1.DestinationRunRequest + respChan chan cpluginv1.DestinationRunResponse reason error m sync.RWMutex @@ -195,7 +194,9 @@ type destinationRunStream struct { func (s *destinationRunStream) Send(resp cpluginv1.DestinationRunResponse) error { select { - case <-s.sendCtx.Done(): + case <-s.ctx.Done(): + return s.ctx.Err() + case <-s.stopChan: return io.EOF case s.respChan <- resp: return nil @@ -204,7 +205,9 @@ func (s *destinationRunStream) Send(resp cpluginv1.DestinationRunResponse) error func (s *destinationRunStream) Recv() (cpluginv1.DestinationRunRequest, error) { select { - case <-s.recvCtx.Done(): + case <-s.ctx.Done(): + return cpluginv1.DestinationRunRequest{}, s.ctx.Err() + case <-s.stopChan: return cpluginv1.DestinationRunRequest{}, io.EOF case req := <-s.reqChan: return req, nil @@ -212,10 +215,10 @@ func (s *destinationRunStream) Recv() (cpluginv1.DestinationRunRequest, error) { } func (s *destinationRunStream) recvInternal() (cpluginv1.DestinationRunResponse, error) { - // note that contexts are named from the perspective of the server, while - // this function is used from the perspective of the client select { - case <-s.sendCtx.Done(): + case <-s.ctx.Done(): + return cpluginv1.DestinationRunResponse{}, cerrors.New(s.ctx.Err().Error()) + case <-s.stopChan: return cpluginv1.DestinationRunResponse{}, s.reason case resp := <-s.respChan: return resp, nil @@ -223,24 +226,24 @@ func (s *destinationRunStream) recvInternal() (cpluginv1.DestinationRunResponse, } func (s *destinationRunStream) sendInternal(req cpluginv1.DestinationRunRequest) error { - // note that contexts are named from the perspective of the server, while - // this function is used from the perspective of the client select { - case <-s.recvCtx.Done(): + case <-s.ctx.Done(): + return cerrors.New(s.ctx.Err().Error()) + case <-s.stopChan: return plugin.ErrStreamNotOpen case s.reqChan <- req: return nil } } -func (s *destinationRunStream) stopAll(reason error) { - s.reason = reason - s.closeSend() - s.closeRecv() -} - -func (s *destinationRunStream) stopSend() { - // we want to stop the stream towards the server, so we close the receiving - // context from the servers perspective - s.closeRecv() +func (s *destinationRunStream) stop(reason error) bool { + select { + case <-s.stopChan: + // channel already closed + return false + default: + s.reason = reason + close(s.stopChan) + return true + } } diff --git a/pkg/plugin/builtin/v1/source.go 
b/pkg/plugin/builtin/v1/source.go
index b3e8caa38..feca3a9f8 100644
--- a/pkg/plugin/builtin/v1/source.go
+++ b/pkg/plugin/builtin/v1/source.go
@@ -97,7 +97,9 @@ func (s *sourcePluginAdapter) Start(ctx context.Context, p record.Position) erro
 		s.logger.Trace(ctx).Msg("calling Run")
 		err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.SourceRunStream(s.stream))
 		if err != nil {
-			s.stream.stop(cerrors.Errorf("error in run: %w", err))
+			if !s.stream.stop(cerrors.Errorf("error in run: %w", err)) {
+				s.logger.Err(ctx, err).Msg("stream already stopped")
+			}
 		} else {
 			s.stream.stop(plugin.ErrStreamNotOpen)
 		}
@@ -164,6 +166,11 @@ func (s *sourcePluginAdapter) Stop(ctx context.Context) (record.Position, error)
 }
 
 func (s *sourcePluginAdapter) Teardown(ctx context.Context) error {
+	// TODO stop stream before calling teardown
+	if s.stream != nil {
+		s.stream.stop(nil)
+	}
+
 	s.logger.Trace(ctx).Msg("calling Teardown")
 	resp, err := runSandbox(s.impl.Teardown, s.withLogger(ctx), toplugin.SourceTeardownRequest())
 	if err != nil {
@@ -229,15 +236,22 @@ func (s *sourceRunStream) recvInternal() (cpluginv1.SourceRunResponse, error) {
 func (s *sourceRunStream) sendInternal(req cpluginv1.SourceRunRequest) error {
 	select {
 	case <-s.ctx.Done():
-		return cerrors.New(s.ctx.Err().Error()) // TODO should this be s.ctx.Err()?
+		return cerrors.New(s.ctx.Err().Error())
 	case <-s.stopChan:
-		return s.reason
+		return plugin.ErrStreamNotOpen
 	case s.reqChan <- req:
 		return nil
 	}
 }
 
-func (s *sourceRunStream) stop(reason error) {
-	s.reason = reason
-	close(s.stopChan)
+func (s *sourceRunStream) stop(reason error) bool {
+	select {
+	case <-s.stopChan:
+		// channel already closed
+		return false
+	default:
+		s.reason = reason
+		close(s.stopChan)
+		return true
+	}
 }
diff --git a/pkg/plugin/standalone/v1/destination.go b/pkg/plugin/standalone/v1/destination.go
index 09f4e2eed..81d37cfd0 100644
--- a/pkg/plugin/standalone/v1/destination.go
+++ b/pkg/plugin/standalone/v1/destination.go
@@ -124,33 +124,35 @@ func (s *destinationPluginClient) Ack(ctx context.Context) (record.Position, err
 }
 
 func (s *destinationPluginClient) Stop(ctx context.Context, lastPosition record.Position) error {
-	var errOut error
 	if s.stream == nil {
 		return plugin.ErrStreamNotOpen
 	}
 
-	err := s.stream.CloseSend()
-	if err != nil {
-		errOut = multierror.Append(errOut, unwrapGRPCError(err))
-	}
-
 	protoReq := toproto.DestinationStopRequest(lastPosition)
 	protoResp, err := s.grpcClient.Stop(ctx, protoReq)
 	if err != nil {
-		errOut = multierror.Append(errOut, unwrapGRPCError(err))
+		return unwrapGRPCError(err)
 	}
 	_ = protoResp // response is empty
-	return errOut
+	return nil
 }
 
 func (s *destinationPluginClient) Teardown(ctx context.Context) error {
+	var errOut error
+	if s.stream != nil {
+		err := s.stream.CloseSend()
+		if err != nil {
+			errOut = multierror.Append(errOut, unwrapGRPCError(err))
+		}
+	}
+
 	protoReq := toproto.DestinationTeardownRequest()
 	protoResp, err := s.grpcClient.Teardown(ctx, protoReq)
 	if err != nil {
-		return unwrapGRPCError(err)
+		errOut = multierror.Append(errOut, unwrapGRPCError(err))
 	}
 	_ = protoResp // response is empty
-	return nil
+	return errOut
 }
diff --git a/pkg/plugin/standalone/v1/source.go b/pkg/plugin/standalone/v1/source.go
index 79d0bdbbf..8df343cdb 100644
--- a/pkg/plugin/standalone/v1/source.go
+++ b/pkg/plugin/standalone/v1/source.go
@@ -115,7 +115,6 @@ func (s *sourcePluginClient) Ack(ctx context.Context, p record.Position) error {
 
 	err = s.stream.Send(protoReq)
 	if err != nil {
-		err = s.ackErrorCause(err)
 		if err ==
io.EOF { // stream was gracefully closed return plugin.ErrStreamNotOpen @@ -125,21 +124,6 @@ func (s *sourcePluginClient) Ack(ctx context.Context, p record.Position) error { return nil } -func (s *sourcePluginClient) ackErrorCause(err error) error { - if err != io.EOF { - // this is an actual error, return it - return err - } - - // actual error can be discovered through Recv, let's do it - _, recvErr := s.stream.Recv() - if recvErr == nil { - // Recv did not return an error, we just read a record, that's a huge bug! - panic(cerrors.Errorf("tried to get error cause of Ack, read a record instead, this is a bug! original error: %w", err)) - } - return recvErr -} - func (s *sourcePluginClient) Stop(ctx context.Context) (record.Position, error) { if s.stream == nil { return nil, plugin.ErrStreamNotOpen From 03cbb0cf6a61dd0083f26811c44e60597ef63c2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 17:06:14 +0200 Subject: [PATCH 25/46] refactor stream, reuse it in source and destination --- pkg/plugin/builtin/v1/destination.go | 84 +++------------------- pkg/plugin/builtin/v1/source.go | 84 +++------------------- pkg/plugin/builtin/v1/stream.go | 103 +++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 152 deletions(-) create mode 100644 pkg/plugin/builtin/v1/stream.go diff --git a/pkg/plugin/builtin/v1/destination.go b/pkg/plugin/builtin/v1/destination.go index c01fd3bc8..38e83cf4a 100644 --- a/pkg/plugin/builtin/v1/destination.go +++ b/pkg/plugin/builtin/v1/destination.go @@ -16,8 +16,6 @@ package builtinv1 import ( "context" - "io" - "sync" "github.com/conduitio/conduit-connector-protocol/cpluginv1" "github.com/conduitio/conduit/pkg/foundation/cerrors" @@ -41,7 +39,7 @@ type destinationPluginAdapter struct { // ctxLogger is attached to the context of each call to the plugin. 
ctxLogger zerolog.Logger - stream *destinationRunStream + stream *stream[cpluginv1.DestinationRunRequest, cpluginv1.DestinationRunResponse] } var _ plugin.DestinationPlugin = (*destinationPluginAdapter)(nil) @@ -92,11 +90,11 @@ func (s *destinationPluginAdapter) Start(ctx context.Context) error { s.logger.Trace(ctx).Msg("calling Run") err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.DestinationRunStream(s.stream)) if err != nil { - if s.stream.stop(cerrors.Errorf("error in run: %w", err)) { + if s.stream.Stop(cerrors.Errorf("error in run: %w", err)) { s.logger.Err(ctx, err).Msg("stream already stopped") } } else { - s.stream.stop(plugin.ErrStreamNotOpen) + s.stream.Stop(plugin.ErrStreamNotOpen) } s.logger.Trace(ctx).Msg("Run stopped") }() @@ -115,7 +113,7 @@ func (s *destinationPluginAdapter) Write(ctx context.Context, r record.Record) ( } s.logger.Trace(ctx).Msg("sending record") - err = s.stream.sendInternal(req) + err = s.stream.SendInternal(req) if err != nil { return cerrors.Errorf("builtin plugin send failed: %w", err) } @@ -129,7 +127,7 @@ func (s *destinationPluginAdapter) Ack(ctx context.Context) (record.Position, er } s.logger.Trace(ctx).Msg("receiving ack") - resp, err := s.stream.recvInternal() + resp, err := s.stream.RecvInternal() if err != nil { return nil, err } @@ -160,7 +158,7 @@ func (s *destinationPluginAdapter) Stop(ctx context.Context, lastPosition record func (s *destinationPluginAdapter) Teardown(ctx context.Context) error { if s.stream != nil { // stop stream if it's open - _ = s.stream.stop(nil) + _ = s.stream.Stop(nil) } s.logger.Trace(ctx).Msg("calling Teardown") @@ -173,77 +171,11 @@ func (s *destinationPluginAdapter) Teardown(ctx context.Context) error { return nil } -func newDestinationRunStream(ctx context.Context) *destinationRunStream { - return &destinationRunStream{ +func newDestinationRunStream(ctx context.Context) *stream[cpluginv1.DestinationRunRequest, cpluginv1.DestinationRunResponse] { + return &stream[cpluginv1.DestinationRunRequest, cpluginv1.DestinationRunResponse]{ ctx: ctx, stopChan: make(chan struct{}), reqChan: make(chan cpluginv1.DestinationRunRequest), respChan: make(chan cpluginv1.DestinationRunResponse), } } - -type destinationRunStream struct { - ctx context.Context - stopChan chan struct{} - reqChan chan cpluginv1.DestinationRunRequest - respChan chan cpluginv1.DestinationRunResponse - - reason error - m sync.RWMutex -} - -func (s *destinationRunStream) Send(resp cpluginv1.DestinationRunResponse) error { - select { - case <-s.ctx.Done(): - return s.ctx.Err() - case <-s.stopChan: - return io.EOF - case s.respChan <- resp: - return nil - } -} - -func (s *destinationRunStream) Recv() (cpluginv1.DestinationRunRequest, error) { - select { - case <-s.ctx.Done(): - return cpluginv1.DestinationRunRequest{}, s.ctx.Err() - case <-s.stopChan: - return cpluginv1.DestinationRunRequest{}, io.EOF - case req := <-s.reqChan: - return req, nil - } -} - -func (s *destinationRunStream) recvInternal() (cpluginv1.DestinationRunResponse, error) { - select { - case <-s.ctx.Done(): - return cpluginv1.DestinationRunResponse{}, cerrors.New(s.ctx.Err().Error()) - case <-s.stopChan: - return cpluginv1.DestinationRunResponse{}, s.reason - case resp := <-s.respChan: - return resp, nil - } -} - -func (s *destinationRunStream) sendInternal(req cpluginv1.DestinationRunRequest) error { - select { - case <-s.ctx.Done(): - return cerrors.New(s.ctx.Err().Error()) - case <-s.stopChan: - return plugin.ErrStreamNotOpen - case s.reqChan <- req: - return nil - } 
-} - -func (s *destinationRunStream) stop(reason error) bool { - select { - case <-s.stopChan: - // channel already closed - return false - default: - s.reason = reason - close(s.stopChan) - return true - } -} diff --git a/pkg/plugin/builtin/v1/source.go b/pkg/plugin/builtin/v1/source.go index feca3a9f8..88eb3266d 100644 --- a/pkg/plugin/builtin/v1/source.go +++ b/pkg/plugin/builtin/v1/source.go @@ -16,8 +16,6 @@ package builtinv1 import ( "context" - "io" - "sync" "github.com/conduitio/conduit-connector-protocol/cpluginv1" "github.com/conduitio/conduit/pkg/foundation/cerrors" @@ -41,7 +39,7 @@ type sourcePluginAdapter struct { // ctxLogger is attached to the context of each call to the plugin. ctxLogger zerolog.Logger - stream *sourceRunStream + stream *stream[cpluginv1.SourceRunRequest, cpluginv1.SourceRunResponse] } var _ plugin.SourcePlugin = (*sourcePluginAdapter)(nil) @@ -97,11 +95,11 @@ func (s *sourcePluginAdapter) Start(ctx context.Context, p record.Position) erro s.logger.Trace(ctx).Msg("calling Run") err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.SourceRunStream(s.stream)) if err != nil { - if s.stream.stop(cerrors.Errorf("error in run: %w", err)) { + if s.stream.Stop(cerrors.Errorf("error in run: %w", err)) { s.logger.Err(ctx, err).Msg("stream already stopped") } } else { - s.stream.stop(plugin.ErrStreamNotOpen) + s.stream.Stop(plugin.ErrStreamNotOpen) } s.logger.Trace(ctx).Msg("Run stopped") }() @@ -115,7 +113,7 @@ func (s *sourcePluginAdapter) Read(ctx context.Context) (record.Record, error) { } s.logger.Trace(ctx).Msg("receiving record") - resp, err := s.stream.recvInternal() + resp, err := s.stream.RecvInternal() if err != nil { return record.Record{}, cerrors.Errorf("builtin plugin receive failed: %w", err) } @@ -139,7 +137,7 @@ func (s *sourcePluginAdapter) Ack(ctx context.Context, p record.Position) error } s.logger.Trace(ctx).Msg("sending ack") - err = s.stream.sendInternal(req) + err = s.stream.SendInternal(req) if err != nil { return cerrors.Errorf("builtin plugin send failed: %w", err) } @@ -168,7 +166,7 @@ func (s *sourcePluginAdapter) Stop(ctx context.Context) (record.Position, error) func (s *sourcePluginAdapter) Teardown(ctx context.Context) error { // TODO stop stream before calling teardown if s.stream != nil { - s.stream.stop(nil) + s.stream.Stop(nil) } s.logger.Trace(ctx).Msg("calling Teardown") @@ -181,77 +179,11 @@ func (s *sourcePluginAdapter) Teardown(ctx context.Context) error { return nil } -func newSourceRunStream(ctx context.Context) *sourceRunStream { - return &sourceRunStream{ +func newSourceRunStream(ctx context.Context) *stream[cpluginv1.SourceRunRequest, cpluginv1.SourceRunResponse] { + return &stream[cpluginv1.SourceRunRequest, cpluginv1.SourceRunResponse]{ ctx: ctx, stopChan: make(chan struct{}), reqChan: make(chan cpluginv1.SourceRunRequest), respChan: make(chan cpluginv1.SourceRunResponse), } } - -type sourceRunStream struct { - ctx context.Context - stopChan chan struct{} - reqChan chan cpluginv1.SourceRunRequest - respChan chan cpluginv1.SourceRunResponse - - reason error - m sync.RWMutex -} - -func (s *sourceRunStream) Send(resp cpluginv1.SourceRunResponse) error { - select { - case <-s.ctx.Done(): - return s.ctx.Err() - case <-s.stopChan: - return io.EOF - case s.respChan <- resp: - return nil - } -} - -func (s *sourceRunStream) Recv() (cpluginv1.SourceRunRequest, error) { - select { - case <-s.ctx.Done(): - return cpluginv1.SourceRunRequest{}, s.ctx.Err() - case <-s.stopChan: - return cpluginv1.SourceRunRequest{}, 
io.EOF - case req := <-s.reqChan: - return req, nil - } -} - -func (s *sourceRunStream) recvInternal() (cpluginv1.SourceRunResponse, error) { - select { - case <-s.ctx.Done(): - return cpluginv1.SourceRunResponse{}, cerrors.New(s.ctx.Err().Error()) - case <-s.stopChan: - return cpluginv1.SourceRunResponse{}, s.reason - case resp := <-s.respChan: - return resp, nil - } -} - -func (s *sourceRunStream) sendInternal(req cpluginv1.SourceRunRequest) error { - select { - case <-s.ctx.Done(): - return cerrors.New(s.ctx.Err().Error()) - case <-s.stopChan: - return plugin.ErrStreamNotOpen - case s.reqChan <- req: - return nil - } -} - -func (s *sourceRunStream) stop(reason error) bool { - select { - case <-s.stopChan: - // channel already closed - return false - default: - s.reason = reason - close(s.stopChan) - return true - } -} diff --git a/pkg/plugin/builtin/v1/stream.go b/pkg/plugin/builtin/v1/stream.go new file mode 100644 index 000000000..ebb9b2bf7 --- /dev/null +++ b/pkg/plugin/builtin/v1/stream.go @@ -0,0 +1,103 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtinv1 + +import ( + "context" + "io" + "sync" + + "github.com/conduitio/conduit/pkg/foundation/cerrors" + "github.com/conduitio/conduit/pkg/plugin" +) + +// stream mimics the behavior of a gRPC stream using channels. +// REQ represents the type sent from the client to the server, RES is the type +// sent from the server to the client. 
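+//
+// For example, the source adapter instantiates it as
+// stream[cpluginv1.SourceRunRequest, cpluginv1.SourceRunResponse]: the plugin
+// streams records back as responses while the adapter sends acks as requests,
+// and the destination adapter mirrors this. A sketch mirroring
+// newSourceRunStream in this patch (illustrative only):
+//
+//	s := &stream[cpluginv1.SourceRunRequest, cpluginv1.SourceRunResponse]{
+//		ctx:      ctx,
+//		stopChan: make(chan struct{}),
+//		reqChan:  make(chan cpluginv1.SourceRunRequest),
+//		respChan: make(chan cpluginv1.SourceRunResponse),
+//	}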
+type stream[REQ any, RES any] struct { + ctx context.Context + stopChan chan struct{} + reqChan chan REQ + respChan chan RES + + reason error + m sync.RWMutex +} + +func (s *stream[REQ, RES]) Send(resp RES) error { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + case <-s.stopChan: + return io.EOF + case s.respChan <- resp: + return nil + } +} + +func (s *stream[REQ, RES]) Recv() (REQ, error) { + select { + case <-s.ctx.Done(): + return s.emptyReq(), s.ctx.Err() + case <-s.stopChan: + return s.emptyReq(), io.EOF + case req := <-s.reqChan: + return req, nil + } +} + +func (s *stream[REQ, RES]) RecvInternal() (RES, error) { + select { + case <-s.ctx.Done(): + return s.emptyRes(), cerrors.New(s.ctx.Err().Error()) + case <-s.stopChan: + return s.emptyRes(), s.reason + case resp := <-s.respChan: + return resp, nil + } +} + +func (s *stream[REQ, RES]) SendInternal(req REQ) error { + select { + case <-s.ctx.Done(): + return cerrors.New(s.ctx.Err().Error()) + case <-s.stopChan: + return plugin.ErrStreamNotOpen + case s.reqChan <- req: + return nil + } +} + +func (s *stream[REQ, RES]) Stop(reason error) bool { + select { + case <-s.stopChan: + // channel already closed + return false + default: + s.reason = reason + close(s.stopChan) + return true + } +} + +func (s *stream[REQ, RES]) emptyReq() REQ { + var r REQ + return r +} + +func (s *stream[REQ, RES]) emptyRes() RES { + var r RES + return r +} From eb69500ac0d232911167d9c30fed822da526a665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Wed, 29 Jun 2022 17:09:39 +0200 Subject: [PATCH 26/46] lock stream when stopping --- pkg/plugin/builtin/v1/destination.go | 10 +++++----- pkg/plugin/builtin/v1/source.go | 11 +++++------ pkg/plugin/builtin/v1/stream.go | 12 +++++++----- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/pkg/plugin/builtin/v1/destination.go b/pkg/plugin/builtin/v1/destination.go index 38e83cf4a..f4d1c9b84 100644 --- a/pkg/plugin/builtin/v1/destination.go +++ b/pkg/plugin/builtin/v1/destination.go @@ -90,11 +90,11 @@ func (s *destinationPluginAdapter) Start(ctx context.Context) error { s.logger.Trace(ctx).Msg("calling Run") err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.DestinationRunStream(s.stream)) if err != nil { - if s.stream.Stop(cerrors.Errorf("error in run: %w", err)) { + if s.stream.stop(cerrors.Errorf("error in run: %w", err)) { s.logger.Err(ctx, err).Msg("stream already stopped") } } else { - s.stream.Stop(plugin.ErrStreamNotOpen) + s.stream.stop(plugin.ErrStreamNotOpen) } s.logger.Trace(ctx).Msg("Run stopped") }() @@ -113,7 +113,7 @@ func (s *destinationPluginAdapter) Write(ctx context.Context, r record.Record) ( } s.logger.Trace(ctx).Msg("sending record") - err = s.stream.SendInternal(req) + err = s.stream.sendInternal(req) if err != nil { return cerrors.Errorf("builtin plugin send failed: %w", err) } @@ -127,7 +127,7 @@ func (s *destinationPluginAdapter) Ack(ctx context.Context) (record.Position, er } s.logger.Trace(ctx).Msg("receiving ack") - resp, err := s.stream.RecvInternal() + resp, err := s.stream.recvInternal() if err != nil { return nil, err } @@ -158,7 +158,7 @@ func (s *destinationPluginAdapter) Stop(ctx context.Context, lastPosition record func (s *destinationPluginAdapter) Teardown(ctx context.Context) error { if s.stream != nil { // stop stream if it's open - _ = s.stream.Stop(nil) + _ = s.stream.stop(nil) } s.logger.Trace(ctx).Msg("calling Teardown") diff --git a/pkg/plugin/builtin/v1/source.go b/pkg/plugin/builtin/v1/source.go index 
88eb3266d..33ea89ad9 100644 --- a/pkg/plugin/builtin/v1/source.go +++ b/pkg/plugin/builtin/v1/source.go @@ -95,11 +95,11 @@ func (s *sourcePluginAdapter) Start(ctx context.Context, p record.Position) erro s.logger.Trace(ctx).Msg("calling Run") err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.SourceRunStream(s.stream)) if err != nil { - if s.stream.Stop(cerrors.Errorf("error in run: %w", err)) { + if s.stream.stop(cerrors.Errorf("error in run: %w", err)) { s.logger.Err(ctx, err).Msg("stream already stopped") } } else { - s.stream.Stop(plugin.ErrStreamNotOpen) + s.stream.stop(plugin.ErrStreamNotOpen) } s.logger.Trace(ctx).Msg("Run stopped") }() @@ -113,7 +113,7 @@ func (s *sourcePluginAdapter) Read(ctx context.Context) (record.Record, error) { } s.logger.Trace(ctx).Msg("receiving record") - resp, err := s.stream.RecvInternal() + resp, err := s.stream.recvInternal() if err != nil { return record.Record{}, cerrors.Errorf("builtin plugin receive failed: %w", err) } @@ -137,7 +137,7 @@ func (s *sourcePluginAdapter) Ack(ctx context.Context, p record.Position) error } s.logger.Trace(ctx).Msg("sending ack") - err = s.stream.SendInternal(req) + err = s.stream.sendInternal(req) if err != nil { return cerrors.Errorf("builtin plugin send failed: %w", err) } @@ -164,9 +164,8 @@ func (s *sourcePluginAdapter) Stop(ctx context.Context) (record.Position, error) } func (s *sourcePluginAdapter) Teardown(ctx context.Context) error { - // TODO stop stream before calling teardown if s.stream != nil { - s.stream.Stop(nil) + s.stream.stop(nil) } s.logger.Trace(ctx).Msg("calling Teardown") diff --git a/pkg/plugin/builtin/v1/stream.go b/pkg/plugin/builtin/v1/stream.go index ebb9b2bf7..6b287ac93 100644 --- a/pkg/plugin/builtin/v1/stream.go +++ b/pkg/plugin/builtin/v1/stream.go @@ -28,12 +28,12 @@ import ( // sent from the server to the client. 
type stream[REQ any, RES any] struct { ctx context.Context - stopChan chan struct{} reqChan chan REQ respChan chan RES + stopChan chan struct{} reason error - m sync.RWMutex + m sync.Mutex } func (s *stream[REQ, RES]) Send(resp RES) error { @@ -58,7 +58,7 @@ func (s *stream[REQ, RES]) Recv() (REQ, error) { } } -func (s *stream[REQ, RES]) RecvInternal() (RES, error) { +func (s *stream[REQ, RES]) recvInternal() (RES, error) { select { case <-s.ctx.Done(): return s.emptyRes(), cerrors.New(s.ctx.Err().Error()) @@ -69,7 +69,7 @@ func (s *stream[REQ, RES]) RecvInternal() (RES, error) { } } -func (s *stream[REQ, RES]) SendInternal(req REQ) error { +func (s *stream[REQ, RES]) sendInternal(req REQ) error { select { case <-s.ctx.Done(): return cerrors.New(s.ctx.Err().Error()) @@ -80,7 +80,9 @@ func (s *stream[REQ, RES]) SendInternal(req REQ) error { } } -func (s *stream[REQ, RES]) Stop(reason error) bool { +func (s *stream[REQ, RES]) stop(reason error) bool { + s.m.Lock() + defer s.m.Unlock() select { case <-s.stopChan: // channel already closed From ef069901bbe2ccbe3ff50bca8a5a196971150e08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 30 Jun 2022 17:00:15 +0200 Subject: [PATCH 27/46] create control message for source stop --- pkg/pipeline/stream/base.go | 23 +++++++++++-- pkg/pipeline/stream/destination.go | 2 ++ pkg/pipeline/stream/message.go | 18 ++++++++-- pkg/pipeline/stream/source.go | 51 +++++++++++++++++++++++----- pkg/pipeline/stream/stream_test.go | 3 -- pkg/plugin/builtin/v1/destination.go | 7 ++-- pkg/plugin/builtin/v1/source.go | 11 +++--- 7 files changed, 90 insertions(+), 25 deletions(-) diff --git a/pkg/pipeline/stream/base.go b/pkg/pipeline/stream/base.go index 235105efe..d75f78156 100644 --- a/pkg/pipeline/stream/base.go +++ b/pkg/pipeline/stream/base.go @@ -95,6 +95,9 @@ type pubNodeBase struct { running bool // lock guards private fields from concurrent changes. lock sync.Mutex + + // msgChan is an internal channel where messages from msgFetcher are collected + msgChan chan *Message } // Trigger sets up 2 goroutines, one that listens to the external error channel @@ -122,8 +125,8 @@ func (n *pubNodeBase) Trigger( } n.running = true + n.msgChan = make(chan *Message) internalErrChan := make(chan error) - msgChan := make(chan *Message) if externalErrChan != nil { // spawn goroutine that forwards external errors into the internal error @@ -150,13 +153,13 @@ func (n *pubNodeBase) Trigger( internalErrChan <- err return } - msgChan <- msg + n.msgChan <- msg } }() } trigger := func() (*Message, error) { - return n.nodeBase.Receive(ctx, logger, msgChan, internalErrChan) + return n.nodeBase.Receive(ctx, logger, n.msgChan, internalErrChan) } cleanup := func() { // TODO make sure spawned goroutines are stopped and internal channels @@ -167,6 +170,20 @@ func (n *pubNodeBase) Trigger( return trigger, cleanup, nil } +// InjectMessage can be used to inject a message into the message stream. This +// is used to inject control messages like the last position message when +// stopping a source connector. It is a bit hacky, but it doesn't require us to +// create a separate channel for signals which makes it performant and easiest +// to implement. 
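+//
+// A typical caller is a node's Stop method handing the node a control
+// message, roughly (sketch; see SourceNode.Stop later in this patch):
+//
+//	err := n.base.InjectMessage(ctx, &Message{
+//		controlMessageType: ControlMessageStopSourceNode,
+//	})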
+func (n *pubNodeBase) InjectMessage(ctx context.Context, message *Message) error {
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case n.msgChan <- message:
+		return nil
+	}
+}
+
 func (n *pubNodeBase) cleanup(ctx context.Context, logger log.CtxLogger) {
 	n.lock.Lock()
 	defer n.lock.Unlock()
diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go
index 23f775b3a..2c6eb850e 100644
--- a/pkg/pipeline/stream/destination.go
+++ b/pkg/pipeline/stream/destination.go
@@ -52,6 +52,8 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) {
 		return cerrors.Errorf("could not open destination connector: %w", err)
 	}
 	defer func() {
+		// TODO stop destination before teardown
+
 		// wait for acker node to receive all outstanding acks, time out after
 		// 1 minute or right away if the context is already canceled.
 		waitCtx, cancel := context.WithTimeout(ctx, time.Minute)
diff --git a/pkg/pipeline/stream/message.go b/pkg/pipeline/stream/message.go
index f8f917e5f..1ab6adb32 100644
--- a/pkg/pipeline/stream/message.go
+++ b/pkg/pipeline/stream/message.go
@@ -26,8 +26,13 @@ import (
 	"github.com/conduitio/conduit/pkg/record"
 )
 
-// MessageStatus represents the state of the message (acked, nacked or open).
-type MessageStatus int
+type (
+	// MessageStatus represents the state of the message (acked, nacked or open).
+	MessageStatus int
+
+	// ControlMessageType represents the type of a control message.
+	ControlMessageType string
+)
 
 const (
 	MessageStatusAcked MessageStatus = iota
@@ -49,6 +54,11 @@ type Message struct {
 	// Record represents a single record attached to the message.
 	Record record.Record
 
+	// controlMessageType is only populated for control messages. Control
+	// messages are special messages injected into the message stream that can
+	// change the behavior of a node and don't need to be acked/nacked.
+	controlMessageType ControlMessageType
+
 	// acked and nacked are channels used to capture acks and nacks. When a
 	// message is acked or nacked the corresponding channel is closed.
 	acked chan struct{}
@@ -114,6 +124,10 @@ func (m *Message) ID() string {
 	return fmt.Sprintf("%s/%s", m.Record.SourceID, m.Record.Position)
 }
 
+func (m *Message) ControlMessageType() ControlMessageType {
+	return m.controlMessageType
+}
+
 // RegisterStatusHandler is used to register a function that will be called on
 // any status change of the message. This function can only be called if the
 // message status is open, otherwise it panics. Handlers are called in the
diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go
index a182a0dad..0c37cecbc 100644
--- a/pkg/pipeline/stream/source.go
+++ b/pkg/pipeline/stream/source.go
@@ -15,6 +15,7 @@
 package stream
 
 import (
+	"bytes"
 	"context"
 	"sync"
 	"time"
@@ -23,7 +24,11 @@ import (
 	"github.com/conduitio/conduit/pkg/foundation/cerrors"
 	"github.com/conduitio/conduit/pkg/foundation/log"
 	"github.com/conduitio/conduit/pkg/foundation/metrics"
-	"github.com/conduitio/conduit/pkg/plugin"
+	"github.com/conduitio/conduit/pkg/record"
+)
+
+const (
+	ControlMessageStopSourceNode ControlMessageType = "stop-source-node"
 )
 
 // SourceNode wraps a Source connector and implements the Pub node interface
@@ -91,14 +96,28 @@ func (n *SourceNode) Run(ctx context.Context) (err error) {
 	}
 	defer cleanup()
 
+	var (
+		// when source node encounters the record with this position it needs to
+		// stop retrieving new records
+		stopPosition record.Position
+		// lastPosition stores the position of the last processed record
+		lastPosition record.Position
+	)
+
 	for {
 		msg, err := trigger()
 		if err != nil || msg == nil {
-			if cerrors.Is(err, plugin.ErrStreamNotOpen) {
-				// node was stopped gracefully, return stop reason
+			return cerrors.Errorf("source stream was stopped unexpectedly: %w", err)
+		}
+
+		if msg.ControlMessageType() == ControlMessageStopSourceNode {
+			// this is a control message telling us to stop
+			stopPosition = msg.Record.Position
+			if bytes.Equal(stopPosition, lastPosition) {
+				// we already encountered the record with the last position
 				return n.stopReason
 			}
-			return err
+			continue
 		}
 
 		// register another open message
@@ -106,8 +125,7 @@
 		msg.RegisterStatusHandler(
 			func(msg *Message, change StatusChange) error {
 				// this is the last handler to be executed, once this handler is
-				// reached we know either the message was successfully acked, nacked
-				// or dropped
+				// reached we know the message was either acked or nacked
 				defer n.PipelineTimer.Update(time.Since(msg.Record.ReadAt))
 				defer wgOpenMessages.Done()
 				return nil
@@ -118,15 +136,30 @@
 		if err != nil {
 			return msg.Nack(err)
 		}
+
+		lastPosition = msg.Record.Position
+		if bytes.Equal(stopPosition, lastPosition) {
+			// it's the last record that we are supposed to process, stop here
+			return n.stopReason
+		}
 	}
 }
 
 func (n *SourceNode) Stop(ctx context.Context, reason error) error {
 	n.logger.Err(ctx, reason).Msg("stopping source connector")
-	n.stopReason = reason
 	lastPosition, err := n.Source.Stop(ctx)
-	_ = lastPosition // TODO store lastPosition
-	return err
+	if err != nil {
+		return cerrors.Errorf("failed to stop source connector: %w", err)
+	}
+
+	n.stopReason = reason
+	// InjectMessage will inject a message into the stream of messages being
+	// processed by SourceNode to let it know when it should stop processing new
+	// messages.
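+	// The control message carries the connector's last position in its
+	// record, so the Run loop above knows which record is the final one to
+	// deliver before returning the stop reason.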
+ return n.base.InjectMessage(ctx, &Message{ + Record: record.Record{Position: lastPosition}, + controlMessageType: ControlMessageStopSourceNode, + }) } func (n *SourceNode) Pub() <-chan *Message { diff --git a/pkg/pipeline/stream/stream_test.go b/pkg/pipeline/stream/stream_test.go index 70426ec79..af6153ee1 100644 --- a/pkg/pipeline/stream/stream_test.go +++ b/pkg/pipeline/stream/stream_test.go @@ -109,7 +109,6 @@ func Example_simpleStream() { // DBG got record message_id=p/generator-10 node_id=printer // DBG received ack message_id=p/generator-10 node_id=generator // INF stopping source connector component=SourceNode node_id=generator - // DBG received error on error channel error="error reading from source: stream not open" component=SourceNode node_id=generator // DBG incoming messages channel closed component=SourceAckerNode node_id=generator-acker // DBG incoming messages channel closed component=DestinationNode node_id=printer // INF finished successfully @@ -274,8 +273,6 @@ func Example_complexStream() { // INF stopping source connector component=SourceNode node_id=generator2 // DBG incoming messages channel closed component=SourceAckerNode node_id=generator1-acker // DBG incoming messages channel closed component=SourceAckerNode node_id=generator2-acker - // DBG received error on error channel error="error reading from source: stream not open" component=SourceNode node_id=generator1 - // DBG received error on error channel error="error reading from source: stream not open" component=SourceNode node_id=generator2 // DBG incoming messages channel closed component=ProcessorNode node_id=counter // DBG incoming messages channel closed component=DestinationNode node_id=printer2 // DBG incoming messages channel closed component=DestinationNode node_id=printer1 diff --git a/pkg/plugin/builtin/v1/destination.go b/pkg/plugin/builtin/v1/destination.go index f4d1c9b84..09823bc81 100644 --- a/pkg/plugin/builtin/v1/destination.go +++ b/pkg/plugin/builtin/v1/destination.go @@ -29,9 +29,10 @@ import ( // destinationPluginAdapter implements the destination plugin interface used // internally in Conduit and relays the calls to a destination plugin defined in -// conduit-connector-protocol (cpluginv1). This adapter needs to make sure it behaves in the -// same way as the standalone plugin adapter, which communicates with the plugin -// through gRPC, so that the caller can use both of them interchangeably. +// conduit-connector-protocol (cpluginv1). This adapter needs to make sure it +// behaves in the same way as the standalone plugin adapter, which communicates +// with the plugin through gRPC, so that the caller can use both of them +// interchangeably. type destinationPluginAdapter struct { impl cpluginv1.DestinationPlugin // logger is used as the internal logger of destinationPluginAdapter. diff --git a/pkg/plugin/builtin/v1/source.go b/pkg/plugin/builtin/v1/source.go index 33ea89ad9..e290fb945 100644 --- a/pkg/plugin/builtin/v1/source.go +++ b/pkg/plugin/builtin/v1/source.go @@ -27,11 +27,12 @@ import ( "github.com/rs/zerolog" ) -// sourcePluginAdapter implements the source plugin interface used -// internally in Conduit and relays the calls to a source plugin defined in -// conduit-connector-protocol (cpluginv1). This adapter needs to make sure it behaves in the -// same way as the standalone plugin adapter, which communicates with the plugin -// through gRPC, so that the caller can use both of them interchangeably. 
+// sourcePluginAdapter implements the source plugin interface used internally in +// Conduit and relays the calls to a source plugin defined in +// conduit-connector-protocol (cpluginv1). This adapter needs to make sure it +// behaves in the same way as the standalone plugin adapter, which communicates +// with the plugin through gRPC, so that the caller can use both of them +// interchangeably. type sourcePluginAdapter struct { impl cpluginv1.SourcePlugin // logger is used as the internal logger of sourcePluginAdapter. From 9e50b1fd13cca1da99661915638e298590fe0cd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 30 Jun 2022 17:28:10 +0200 Subject: [PATCH 28/46] forward last position to destination --- pkg/pipeline/stream/destination.go | 16 +++++++++++++++- pkg/pipeline/stream/stream_test.go | 16 ++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go index 2c6eb850e..3574923de 100644 --- a/pkg/pipeline/stream/destination.go +++ b/pkg/pipeline/stream/destination.go @@ -22,6 +22,7 @@ import ( "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/conduitio/conduit/pkg/foundation/log" "github.com/conduitio/conduit/pkg/foundation/metrics" + "github.com/conduitio/conduit/pkg/record" ) // DestinationNode wraps a Destination connector and implements the Sub node interface @@ -51,8 +52,20 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { if err != nil { return cerrors.Errorf("could not open destination connector: %w", err) } + + // lastPosition stores the position of the last successfully processed record + var lastPosition record.Position defer func() { - // TODO stop destination before teardown + stopErr := n.Destination.Stop(connectorCtx, lastPosition) + if stopErr != nil { + // log this error right away because we're not sure the connector + // will be able to stop right away, we might block for 1 minute + // waiting for acks and we don't want the log to be empty + n.logger.Err(ctx, err).Msg("could not stop destination connector") + if err == nil { + err = stopErr + } + } // wait for acker node to receive all outstanding acks, time out after // 1 minute or right away if the context is already canceled. @@ -106,6 +119,7 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { _ = msg.Nack(err) return cerrors.Errorf("error writing to destination: %w", err) } + lastPosition = msg.Record.Position n.ConnectorTimer.Update(time.Since(writeTime)) } } diff --git a/pkg/pipeline/stream/stream_test.go b/pkg/pipeline/stream/stream_test.go index af6153ee1..8cd20749a 100644 --- a/pkg/pipeline/stream/stream_test.go +++ b/pkg/pipeline/stream/stream_test.go @@ -332,6 +332,7 @@ func generatorSource(ctrl *gomock.Controller, logger log.CtxLogger, nodeID strin } func printerDestination(ctrl *gomock.Controller, logger log.CtxLogger, nodeID string) connector.Destination { + var lastPosition record.Position rchan := make(chan record.Record) destination := connmock.NewDestination(ctrl) destination.EXPECT().Open(gomock.Any()).Return(nil).Times(1) @@ -339,6 +340,7 @@ func printerDestination(ctrl *gomock.Controller, logger log.CtxLogger, nodeID st logger.Debug(ctx). Str("node_id", nodeID). 
Msg("got record") + lastPosition = r.Position rchan <- r return nil }).AnyTimes() @@ -353,6 +355,7 @@ func printerDestination(ctrl *gomock.Controller, logger log.CtxLogger, nodeID st return r.Position, nil } }).AnyTimes() + destination.EXPECT().Stop(gomock.Any(), EqLazy(func() interface{} { return lastPosition })).Return(nil).Times(1) destination.EXPECT().Teardown(gomock.Any()).DoAndReturn(func(ctx context.Context) error { close(rchan) return nil @@ -405,3 +408,16 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true // timed out } } + +func EqLazy(x func() interface{}) gomock.Matcher { return eqMatcherLazy{x} } + +type eqMatcherLazy struct { + x func() interface{} +} + +func (e eqMatcherLazy) Matches(x interface{}) bool { + return gomock.Eq(e.x()).Matches(x) +} +func (e eqMatcherLazy) String() string { + return gomock.Eq(e.x()).String() +} From c46310df9b2bde1557fc9568a09b9dfd0d323fa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 30 Jun 2022 19:51:07 +0200 Subject: [PATCH 29/46] update connector SDK, fix race condition in source node --- go.mod | 9 +++++---- go.sum | 22 +++++++++++----------- pkg/pipeline/stream/destination_acker.go | 4 ++-- pkg/pipeline/stream/source.go | 16 +++++++--------- 4 files changed, 25 insertions(+), 26 deletions(-) diff --git a/go.mod b/go.mod index d31ba4401..1bf0cee90 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/conduitio/conduit-connector-postgres v0.1.0 github.com/conduitio/conduit-connector-protocol v0.2.1-0.20220608133528-f466a956bd4d github.com/conduitio/conduit-connector-s3 v0.1.1 - github.com/conduitio/conduit-connector-sdk v0.2.0 + github.com/conduitio/conduit-connector-sdk v0.2.1-0.20220622151135-47f1a8905435 github.com/dgraph-io/badger/v3 v3.2103.2 github.com/dop251/goja v0.0.0-20210225094849-f3cfc97811c0 github.com/golang/mock v1.6.0 @@ -33,7 +33,7 @@ require ( go.buf.build/library/go-grpc/conduitio/conduit-connector-protocol v1.4.2 golang.org/x/tools v0.1.11 golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df - google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad + google.golang.org/genproto v0.0.0-20220630150403-404d0664e509 google.golang.org/grpc v1.47.0 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.2.0 google.golang.org/protobuf v1.28.0 @@ -122,13 +122,14 @@ require ( github.com/xitongsys/parquet-go-source v0.0.0-20220315005136-aec0fe3e777c // indirect go.opencensus.io v0.23.0 // indirect go.uber.org/atomic v1.9.0 // indirect + go.uber.org/goleak v1.1.12 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect - golang.org/x/net v0.0.0-20220617184016-355a448f1bc9 // indirect + golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect - golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c // indirect + golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b // indirect golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 // indirect golang.org/x/text v0.3.7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect diff --git a/go.sum b/go.sum index 92d9402d0..85a2215a0 100644 --- a/go.sum +++ b/go.sum @@ -155,14 +155,12 @@ github.com/conduitio/conduit-connector-kafka v0.1.1 h1:RgN4nafEWjpA4VvXLdSQBNrEO github.com/conduitio/conduit-connector-kafka v0.1.1/go.mod 
h1:+CbMUq4fIMxFnrINtjxuVTW5TZYa549WJXQFb63GaIU= github.com/conduitio/conduit-connector-postgres v0.1.0 h1:Dj2S1NrwnJaUOgQqb9MjGSl2vv2gre0mSFE2Ne/5OSE= github.com/conduitio/conduit-connector-postgres v0.1.0/go.mod h1:ug4N+2pGKDbG5UN++w7xRqb0A5Ua2J5Ld5wUzLbU1Q0= -github.com/conduitio/conduit-connector-protocol v0.2.0 h1:gwYXVKEMgTtU67ephQ5WwTGIDbT/eTLA9Mdr9Bnbqxc= -github.com/conduitio/conduit-connector-protocol v0.2.0/go.mod h1:udCU2AkLcYQoLjAO06tHVL2iFJPw+DamK+wllnj50hk= github.com/conduitio/conduit-connector-protocol v0.2.1-0.20220608133528-f466a956bd4d h1:f3R0yPiH45hDZwNcYMSzKJP6LOGQPELCqW9OkZmd2lA= github.com/conduitio/conduit-connector-protocol v0.2.1-0.20220608133528-f466a956bd4d/go.mod h1:1nmTaD+l3mvq3PnMmPPx8UxHPM53Xk8zGT3URu2Xx2M= github.com/conduitio/conduit-connector-s3 v0.1.1 h1:10uIakNmF65IN5TNJB1qPWC6vbdGgrHEMg8r+dxDrc8= github.com/conduitio/conduit-connector-s3 v0.1.1/go.mod h1:xpfBzOGjZkkglTmF1444qEjXuEx+do1PTYZNroPFcSE= -github.com/conduitio/conduit-connector-sdk v0.2.0 h1:yReJT3SOAGqJIlk59WC5FPgpv0Gg+NG4NFj6FJ89XnM= -github.com/conduitio/conduit-connector-sdk v0.2.0/go.mod h1:zZ/YJqhIZyXdVmFJS55zqkukpBmB+ohbX2kDduoj8Z0= +github.com/conduitio/conduit-connector-sdk v0.2.1-0.20220622151135-47f1a8905435 h1:/bjfGf/vG8vV5WjDb7vcsluVxPZVvfsYRF4nhzJg8q4= +github.com/conduitio/conduit-connector-sdk v0.2.1-0.20220622151135-47f1a8905435/go.mod h1:RVVcsR1JBSyN8cxzjBVMyTKDym3KS6MXD2Lons/Wsw4= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -612,8 +610,9 @@ go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -678,6 +677,7 @@ golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= @@ -728,8 +728,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod 
h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220617184016-355a448f1bc9 h1:Yqz/iviulwKwAREEeUd3nbBFn0XuyJqkoft2IlrvOhc= -golang.org/x/net v0.0.0-20220617184016-355a448f1bc9/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e h1:TsQ7F31D3bUCLeqPT0u+yjp1guoArKaNKmCr22PYgTQ= +golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -807,8 +807,8 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c h1:aFV+BgZ4svzjfabn8ERpuB4JI4N6/rdy1iusx77G3oU= -golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b h1:2n253B2r0pYSmEV+UNCQoPfU/FiaizQEK5Gu4Bq4JE8= +golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -950,8 +950,8 @@ google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210630183607-d20f26d13c79/go.mod h1:yiaVoXHpRzHGyxV3o4DktVWY4mSUErTKaeEOq6C3t3U= -google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad h1:kqrS+lhvaMHCxul6sKQvKJ8nAAhlVItmZV822hYFH/U= -google.golang.org/genproto v0.0.0-20220617124728-180714bec0ad/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= +google.golang.org/genproto v0.0.0-20220630150403-404d0664e509 h1:eUofWZEQ3SqKIW6WImdM2sxVVjnL0ahOYuIYC6WEYI8= +google.golang.org/genproto v0.0.0-20220630150403-404d0664e509/go.mod h1:KEWEmljWE5zPzLBa/oHl6DaEt9LmfH6WtH1OHIvleBA= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index 317707ecb..960fd424f 100644 
--- a/pkg/pipeline/stream/destination_acker.go
+++ b/pkg/pipeline/stream/destination_acker.go
@@ -214,12 +214,12 @@ func (n *DestinationAckerNode) Wait(ctx context.Context) {
 		}
 		n.logger.Debug(ctx).
 			Int("remaining", cacheSize).
-			Msg("waiting for acker node to process remaining acks")
+			Msg("waiting for destination acker node to process remaining acks")
 		select {
 		case <-ctx.Done():
 			n.logger.Warn(ctx).
 				Int("remaining", cacheSize).
-				Msg("stopped waiting for acker node even though some acks may be remaining")
+				Msg("stopped waiting for destination acker node even though some acks may be remaining")
 			return
 		case <-t.C:
 		}
diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go
index 0c37cecbc..08b3f6f17 100644
--- a/pkg/pipeline/stream/source.go
+++ b/pkg/pipeline/stream/source.go
@@ -112,7 +112,13 @@ func (n *SourceNode) Run(ctx context.Context) (err error) {
 
 		if msg.ControlMessageType() == ControlMessageStopSourceNode {
 			// this is a control message telling us to stop
-			stopPosition = msg.Record.Position
+			n.logger.Err(ctx, n.stopReason).Msg("stopping source connector")
+			stopPosition, err = n.Source.Stop(ctx)
+			if err != nil {
+				// TODO think through if just exiting here makes sense
+				return cerrors.Errorf("failed to stop source connector: %w", err)
+			}
+
 			if bytes.Equal(stopPosition, lastPosition) {
 				// we already encountered the record with the last position
 				return n.stopReason
 			}
 			continue
 		}
@@ -146,18 +152,10 @@ func (n *SourceNode) Run(ctx context.Context) (err error) {
 }
 
 func (n *SourceNode) Stop(ctx context.Context, reason error) error {
-	n.logger.Err(ctx, reason).Msg("stopping source connector")
-	lastPosition, err := n.Source.Stop(ctx)
-	if err != nil {
-		return cerrors.Errorf("failed to stop source connector: %w", err)
-	}
-
-	n.stopReason = reason
 	// InjectMessage will inject a message into the stream of messages being
 	// processed by SourceNode to let it know when it should stop processing new
 	// messages.
 	return n.base.InjectMessage(ctx, &Message{
-		Record:             record.Record{Position: lastPosition},
 		controlMessageType: ControlMessageStopSourceNode,
 	})
 }

From a6a898b97c1d2c21e9c30a806d84637671863274 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Fri, 1 Jul 2022 20:37:09 +0200
Subject: [PATCH 30/46] make Conduit in charge of closing connector streams

* Change plugin semantics around teardown - internal connector entity is now
  in charge of closing the stream instead of plugin.
* Map known gRPC errors to internal type (context.Canceled).
* Rewrite DestinationAckerNode to be a regular node standing after
  DestinationNode, receiving messages and triggering ack receiving. This makes
  the structure simpler and in line with all other nodes.
* Create OpenMessagesTracker to simplify tracking open messages in SourceNode
  and DestinationNode.
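The OpenMessagesTracker named in the last bullet is not shown in this excerpt.
Based on its usage in destination.go below (Add for every successfully written
message, Wait in the deferred teardown), a minimal sketch would be a WaitGroup
hooked into the message status handlers (assumed shape, the actual
implementation may differ):

	type OpenMessagesTracker struct{ wg sync.WaitGroup }

	func (t *OpenMessagesTracker) Add(msg *Message) {
		t.wg.Add(1)
		msg.RegisterStatusHandler(
			func(msg *Message, change StatusChange) error {
				t.wg.Done() // runs once the message is acked or nacked
				return nil
			},
		)
	}

	func (t *OpenMessagesTracker) Wait() { t.wg.Wait() }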
--- go.mod | 1 + go.sum | 2 + pkg/connector/destination.go | 14 +- pkg/connector/source.go | 14 +- pkg/pipeline/lifecycle.go | 2 +- pkg/pipeline/lifecycle_test.go | 27 +- pkg/pipeline/stream/base.go | 16 +- pkg/pipeline/stream/base_test.go | 16 +- pkg/pipeline/stream/destination.go | 40 ++- pkg/pipeline/stream/destination_acker.go | 307 ++++++------------ pkg/pipeline/stream/destination_acker_test.go | 112 ------- pkg/pipeline/stream/message.go | 20 ++ pkg/pipeline/stream/metrics.go | 2 +- pkg/pipeline/stream/processor.go | 2 +- pkg/pipeline/stream/source.go | 20 +- pkg/pipeline/stream/source_acker.go | 2 +- pkg/pipeline/stream/stream_test.go | 32 +- pkg/plugin/acceptance_testing.go | 55 +--- pkg/plugin/builtin/v1/destination.go | 7 +- pkg/plugin/builtin/v1/source.go | 6 +- pkg/plugin/builtin/v1/stream.go | 5 +- pkg/plugin/standalone/v1/client.go | 15 +- pkg/plugin/standalone/v1/destination.go | 13 +- pkg/plugin/standalone/v1/source.go | 14 +- 24 files changed, 265 insertions(+), 479 deletions(-) delete mode 100644 pkg/pipeline/stream/destination_acker_test.go diff --git a/go.mod b/go.mod index 1bf0cee90..f2141850b 100644 --- a/go.mod +++ b/go.mod @@ -72,6 +72,7 @@ require ( github.com/dustin/go-humanize v1.0.0 // indirect github.com/fatih/color v1.13.0 // indirect github.com/fsnotify/fsnotify v1.5.1 // indirect + github.com/gammazero/deque v0.2.0 // indirect github.com/go-chi/chi/v5 v5.0.7 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/gofrs/flock v0.8.1 // indirect diff --git a/go.sum b/go.sum index 85a2215a0..5a636ea26 100644 --- a/go.sum +++ b/go.sum @@ -205,6 +205,8 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= +github.com/gammazero/deque v0.2.0 h1:SkieyNB4bg2/uZZLxvya0Pq6diUlwx7m2TeT7GAIWaA= +github.com/gammazero/deque v0.2.0/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= diff --git a/pkg/connector/destination.go b/pkg/connector/destination.go index 695805ed2..1db2516fe 100644 --- a/pkg/connector/destination.go +++ b/pkg/connector/destination.go @@ -53,6 +53,9 @@ type destination struct { // plugin is the running instance of the destination plugin. plugin plugin.DestinationPlugin + // stopStream is a function that closes the context of the stream + stopStream context.CancelFunc + // m can lock a destination from concurrent access (e.g. in connector persister). m sync.Mutex // wg tracks the number of in flight calls to the plugin. 
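The ordering in Teardown matters here: stopStream cancels the stream context so
the plugin's blocked stream calls return, which in turn lets the in-flight call
counter drain. Condensed as a sketch (the actual hunks follow below):

	// sketch of the teardown ordering implemented in this patch
	if s.stopStream != nil {
		s.stopStream() // cancel the stream context, unblocking Stop/Ack/Write
		s.stopStream = nil
	}
	s.wg.Wait() // in-flight plugin calls can now finish before teardown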
@@ -144,8 +147,10 @@ func (s *destination) Open(ctx context.Context) error { return err } - err = dest.Start(ctx) + streamCtx, cancelStreamCtx := context.WithCancel(ctx) + err = dest.Start(streamCtx) if err != nil { + cancelStreamCtx() _ = dest.Teardown(ctx) return err } @@ -153,6 +158,7 @@ func (s *destination) Open(ctx context.Context) error { s.logger.Info(ctx).Msg("destination connector plugin successfully started") s.plugin = dest + s.stopStream = cancelStreamCtx s.persister.ConnectorStarted() return nil } @@ -184,6 +190,12 @@ func (s *destination) Teardown(ctx context.Context) error { return plugin.ErrPluginNotRunning } + // close stream + if s.stopStream != nil { + s.stopStream() + s.stopStream = nil + } + // wait for any calls to the plugin to stop running first (e.g. Stop, Ack or Write) s.wg.Wait() diff --git a/pkg/connector/source.go b/pkg/connector/source.go index fe7570a68..336fbb81f 100644 --- a/pkg/connector/source.go +++ b/pkg/connector/source.go @@ -52,6 +52,9 @@ type source struct { // plugin is the running instance of the source plugin. plugin plugin.SourcePlugin + // stopStream is a function that closes the context of the stream + stopStream context.CancelFunc + // m can lock a source from concurrent access (e.g. in connector persister). m sync.Mutex // wg tracks the number of in flight calls to the plugin. @@ -143,8 +146,10 @@ func (s *source) Open(ctx context.Context) error { return err } - err = src.Start(ctx, s.XState.Position) + streamCtx, cancelStreamCtx := context.WithCancel(ctx) + err = src.Start(streamCtx, s.XState.Position) if err != nil { + cancelStreamCtx() _ = src.Teardown(ctx) return err } @@ -152,6 +157,7 @@ func (s *source) Open(ctx context.Context) error { s.logger.Info(ctx).Msg("source connector plugin successfully started") s.plugin = src + s.stopStream = cancelStreamCtx s.persister.ConnectorStarted() return nil } @@ -183,6 +189,12 @@ func (s *source) Teardown(ctx context.Context) error { return plugin.ErrPluginNotRunning } + // close stream + if s.stopStream != nil { + s.stopStream() + s.stopStream = nil + } + // wait for any calls to the plugin to stop running first (e.g. 
Stop, Ack or Read) s.wg.Wait() diff --git a/pkg/pipeline/lifecycle.go b/pkg/pipeline/lifecycle.go index cac735f29..cd9ed0fa1 100644 --- a/pkg/pipeline/lifecycle.go +++ b/pkg/pipeline/lifecycle.go @@ -340,10 +340,10 @@ func (s *Service) buildDestinationNodes( instance.Config().Plugin, strings.ToLower(instance.Type().String()), ), - AckerNode: ackerNode, } metricsNode := s.buildMetricsNode(pl, instance) destinationNode.Sub(metricsNode.Pub()) + ackerNode.Sub(destinationNode.Pub()) connNodes, err := s.buildProcessorNodes(ctx, procFetcher, pl, instance.Config().ProcessorIDs, prev, metricsNode) if err != nil { diff --git a/pkg/pipeline/lifecycle_test.go b/pkg/pipeline/lifecycle_test.go index 2424c3aa4..ac1e8408e 100644 --- a/pkg/pipeline/lifecycle_test.go +++ b/pkg/pipeline/lifecycle_test.go @@ -36,7 +36,6 @@ import ( ) func TestServiceLifecycle_PipelineSuccess(t *testing.T) { - t.Skip("TODO need to change test to keep source running forever") ctx, killAll := context.WithCancel(context.Background()) defer killAll() @@ -51,7 +50,7 @@ func TestServiceLifecycle_PipelineSuccess(t *testing.T) { // create mocked connectors ctrl := gomock.NewController(t) - source, wantRecords := generatorSource(ctrl, 10, ctx.Err(), false) + source, wantRecords := generatorSource(ctrl, 10, nil, false) destination := asserterDestination(ctrl, t, wantRecords, false) pl, err = ps.AddConnector(ctx, pl, source.ID()) @@ -197,6 +196,13 @@ func (tpf testProcessorFetcher) Get(ctx context.Context, id string) (*processor. func generatorSource(ctrl *gomock.Controller, recordCount int, wantErr error, teardown bool) (connector.Source, []record.Record) { position := 0 records := make([]record.Record, recordCount) + for i := 0; i < recordCount; i++ { + records[i] = record.Record{ + Key: record.RawData{Raw: []byte(uuid.NewString())}, + Payload: record.RawData{Raw: []byte(uuid.NewString())}, + Position: record.Position(strconv.Itoa(i)), + } + } source := basicSourceMock(ctrl) if teardown { @@ -206,15 +212,13 @@ func generatorSource(ctrl *gomock.Controller, recordCount int, wantErr error, te source.EXPECT().Ack(gomock.Any(), gomock.Any()).Return(nil).Times(recordCount) source.EXPECT().Read(gomock.Any()).DoAndReturn(func(ctx context.Context) (record.Record, error) { if position == recordCount { - return record.Record{}, wantErr - } - r := record.Record{ - Key: record.RawData{Raw: []byte(uuid.NewString())}, - Payload: record.RawData{Raw: []byte(uuid.NewString())}, - Position: record.Position(strconv.Itoa(position)), + if wantErr != nil { + return record.Record{}, wantErr + } + <-ctx.Done() + return record.Record{}, ctx.Err() } - - records[position] = r + r := records[position] position++ return r, nil }).MinTimes(recordCount + 1) @@ -236,7 +240,7 @@ func basicSourceMock(ctrl *gomock.Controller) *connmock.Source { // match the expected records. On teardown it also makes sure that it received // all expected records. 
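// With destinations now stopped gracefully, the mock is also expected to
// receive a Stop call carrying the position of the last record it was given
// (see the added EXPECT for Stop in the following hunk).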
func asserterDestination(ctrl *gomock.Controller, t *testing.T, want []record.Record, teardown bool) connector.Destination { - rchan := make(chan record.Record) + rchan := make(chan record.Record, 1) recordCount := 0 destination := connmock.NewDestination(ctrl) @@ -246,6 +250,7 @@ func asserterDestination(ctrl *gomock.Controller, t *testing.T, want []record.Re destination.EXPECT().Open(gomock.Any()).Return(nil).Times(1) destination.EXPECT().Errors().Return(make(chan error)) if teardown { + destination.EXPECT().Stop(gomock.Any(), want[len(want)-1].Position).Return(nil).Times(1) destination.EXPECT().Teardown(gomock.Any()).DoAndReturn(func(ctx context.Context) error { close(rchan) return nil diff --git a/pkg/pipeline/stream/base.go b/pkg/pipeline/stream/base.go index d75f78156..957474505 100644 --- a/pkg/pipeline/stream/base.go +++ b/pkg/pipeline/stream/base.go @@ -48,8 +48,9 @@ type pubSubNodeBase struct { func (n *pubSubNodeBase) Trigger( ctx context.Context, logger log.CtxLogger, + externalErrChan <-chan error, ) (triggerFunc, cleanupFunc, error) { - trigger, cleanup1, err := n.subNodeBase.Trigger(ctx, logger, nil) + trigger, cleanup1, err := n.subNodeBase.Trigger(ctx, logger, externalErrChan) if err != nil { return nil, nil, err } @@ -150,7 +151,11 @@ func (n *pubNodeBase) Trigger( for { msg, err := msgFetcher(ctx) if err != nil { - internalErrChan <- err + if !cerrors.Is(err, context.Canceled) { + // ignore context error because it is going to be caught + // by nodeBase.Receive anyway + internalErrChan <- err + } return } n.msgChan <- msg @@ -176,6 +181,8 @@ func (n *pubNodeBase) Trigger( // create a separate channel for signals which makes it performant and easiest // to implement. func (n *pubNodeBase) InjectMessage(ctx context.Context, message *Message) error { + n.lock.Lock() + defer n.lock.Unlock() select { case <-ctx.Done(): return ctx.Err() @@ -252,11 +259,6 @@ func (n *subNodeBase) Trigger( n.running = true - if errChan == nil { - // create dummy channel - errChan = make(chan error) - } - trigger := func() (*Message, error) { return n.nodeBase.Receive(ctx, logger, n.in, errChan) } diff --git a/pkg/pipeline/stream/base_test.go b/pkg/pipeline/stream/base_test.go index bf3b09181..97a3d86ab 100644 --- a/pkg/pipeline/stream/base_test.go +++ b/pkg/pipeline/stream/base_test.go @@ -29,21 +29,21 @@ func TestPubSubNodeBase_TriggerWithoutPubOrSub(t *testing.T) { logger := log.Nop() n := &pubSubNodeBase{} - trigger, cleanup, err := n.Trigger(ctx, logger) + trigger, cleanup, err := n.Trigger(ctx, logger, nil) assert.Nil(t, trigger) assert.Nil(t, cleanup) assert.Error(t, err) n = &pubSubNodeBase{} n.Pub() - trigger, cleanup, err = n.Trigger(ctx, logger) + trigger, cleanup, err = n.Trigger(ctx, logger, nil) assert.Nil(t, trigger) assert.Nil(t, cleanup) assert.Error(t, err) n = &pubSubNodeBase{} n.Sub(make(chan *Message)) - trigger, cleanup, err = n.Trigger(ctx, logger) + trigger, cleanup, err = n.Trigger(ctx, logger, nil) assert.Nil(t, trigger) assert.Nil(t, cleanup) assert.Error(t, err) @@ -76,12 +76,12 @@ func TestPubSubNodeBase_TriggerTwice(t *testing.T) { n := &pubSubNodeBase{} n.Pub() n.Sub(make(chan *Message)) - trigger, cleanup, err := n.Trigger(ctx, logger) + trigger, cleanup, err := n.Trigger(ctx, logger, nil) assert.Ok(t, err) assert.NotNil(t, trigger) assert.NotNil(t, cleanup) - trigger, cleanup, err = n.Trigger(ctx, logger) + trigger, cleanup, err = n.Trigger(ctx, logger, nil) assert.Nil(t, trigger) assert.Nil(t, cleanup) assert.Error(t, err) @@ -96,7 +96,7 @@ func 
TestPubSubNodeBase_TriggerSuccess(t *testing.T) { n.Sub(in) n.Pub() - trigger, cleanup, err := n.Trigger(ctx, logger) + trigger, cleanup, err := n.Trigger(ctx, logger, nil) assert.Ok(t, err) assert.NotNil(t, trigger) assert.NotNil(t, cleanup) @@ -123,7 +123,7 @@ func TestPubSubNodeBase_TriggerClosedSubChannel(t *testing.T) { n.Sub(in) n.Pub() - trigger, cleanup, err := n.Trigger(ctx, logger) + trigger, cleanup, err := n.Trigger(ctx, logger, nil) assert.Ok(t, err) assert.NotNil(t, trigger) assert.NotNil(t, cleanup) @@ -147,7 +147,7 @@ func TestPubSubNodeBase_TriggerCancelledContext(t *testing.T) { n.Sub(in) n.Pub() - trigger, cleanup, err := n.Trigger(ctx, logger) + trigger, cleanup, err := n.Trigger(ctx, logger, nil) assert.Ok(t, err) assert.NotNil(t, trigger) assert.NotNil(t, cleanup) diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go index 3574923de..4149170b2 100644 --- a/pkg/pipeline/stream/destination.go +++ b/pkg/pipeline/stream/destination.go @@ -30,10 +30,8 @@ type DestinationNode struct { Name string Destination connector.Destination ConnectorTimer metrics.Timer - // AckerNode is responsible for handling acks - AckerNode *DestinationAckerNode - base subNodeBase + base pubSubNodeBase logger log.CtxLogger } @@ -55,6 +53,8 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { // lastPosition stores the position of the last successfully processed record var lastPosition record.Position + // openMsgTracker tracks open messages until they are acked or nacked + var openMsgTracker OpenMessagesTracker defer func() { stopErr := n.Destination.Stop(connectorCtx, lastPosition) if stopErr != nil { @@ -67,11 +67,7 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { } } - // wait for acker node to receive all outstanding acks, time out after - // 1 minute or right away if the context is already canceled. - waitCtx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - n.AckerNode.Wait(waitCtx) + openMsgTracker.Wait() // teardown will kill the plugin process tdErr := n.Destination.Teardown(connectorCtx) @@ -99,28 +95,25 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { n.logger.Trace(msg.Ctx).Msg("writing record to destination connector") - // first signal ack handler we might receive an ack, since this could - // already happen before write returns - err = n.AckerNode.ExpectAck(msg) - if err != nil { - return err - } - writeTime := time.Now() err = n.Destination.Write(msg.Ctx, msg.Record) if err != nil { // An error in Write is a fatal error, we probably won't be able to // process any further messages because there is a problem in the - // communication with the plugin. We need to let the acker node know - // that it shouldn't wait to receive an ack for the message, we need - // to nack the message to not leave it open and then return the - // error to stop the pipeline. - n.AckerNode.Forget(msg) + // communication with the plugin. We need to nack the message to not + // leave it open and then return the error to stop the pipeline. _ = msg.Nack(err) return cerrors.Errorf("error writing to destination: %w", err) } - lastPosition = msg.Record.Position n.ConnectorTimer.Update(time.Since(writeTime)) + + openMsgTracker.Add(msg) + lastPosition = msg.Record.Position + + err = n.base.Send(ctx, n.logger, msg) + if err != nil { + return msg.Nack(err) + } } } @@ -129,6 +122,11 @@ func (n *DestinationNode) Sub(in <-chan *Message) { n.base.Sub(in) } +// Pub will return the outgoing channel. 
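+// Messages published here have already been written to the destination;
+// DestinationAckerNode subscribes to this channel and matches each message
+// with the acknowledgment it fetches from the connector.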
+func (n *DestinationNode) Pub() <-chan *Message { + return n.base.Pub() +} + // SetLogger sets the logger. func (n *DestinationNode) SetLogger(logger log.CtxLogger) { n.logger = logger diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index 960fd424f..0db2d835c 100644 --- a/pkg/pipeline/stream/destination_acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -15,17 +15,15 @@ package stream import ( + "bytes" "context" "sync" - "sync/atomic" - "time" "github.com/conduitio/conduit/pkg/connector" "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/conduitio/conduit/pkg/foundation/log" "github.com/conduitio/conduit/pkg/foundation/multierror" - "github.com/conduitio/conduit/pkg/plugin" - "github.com/conduitio/conduit/pkg/record" + "github.com/gammazero/deque" ) // DestinationAckerNode is responsible for handling acknowledgments received @@ -34,114 +32,117 @@ type DestinationAckerNode struct { Name string Destination connector.Destination - logger log.CtxLogger - // cache stores the messages that are still waiting for an ack/nack. - cache *positionMessageMap - - // start is closed once the first message is received in the destination node. - start chan struct{} - // stop is closed once the last message is received in the destination node. - stop chan struct{} - // initOnce initializes internal fields. - initOnce sync.Once - // startOnce closes start. - startOnce sync.Once - // stopOnce closes stop. - stopOnce sync.Once -} + // queue is used to store messages + queue deque.Deque[*Message] + // m guards access to queue + m sync.Mutex -// init initializes DestinationAckerNode internal fields. -func (n *DestinationAckerNode) init() { - n.initOnce.Do(func() { - n.cache = &positionMessageMap{} - n.start = make(chan struct{}) - n.stop = make(chan struct{}) - }) + base subNodeBase + logger log.CtxLogger } func (n *DestinationAckerNode) ID() string { return n.Name } -// Run continuously fetches acks from the destination and forwards them to the -// correct message by calling Ack or Nack on that message. 
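+// Run receives messages that were successfully written to the destination,
+// buffers them in an internal queue and lets a worker goroutine match them,
+// in order, to the acks fetched from the destination connector, calling Ack
+// or Nack on each message accordingly.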
func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { - n.logger.Trace(ctx).Msg("starting acker node") - defer n.logger.Trace(ctx).Msg("acker node stopped") + // start a fresh connector context to make sure the connector is running + // until this method returns + connectorCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + signalChan := make(chan struct{}) + errChan := make(chan error) - n.init() defer func() { + close(signalChan) + workerErr := <-errChan + if workerErr != nil { + if err != nil { + // we are already returning an error, log this one instead + n.logger.Err(ctx, workerErr).Msg("destination acker node worker failed") + } else { + err = workerErr + } + } teardownErr := n.teardown(err) - if err != nil { - // we are already returning an error, just log this one - n.logger.Err(ctx, teardownErr).Msg("acker node stopped without processing all messages") - } else { - // return teardownErr instead - err = teardownErr + if teardownErr != nil { + if err != nil { + // we are already returning an error, log this one instead + n.logger.Err(ctx, teardownErr).Msg("destination acker node stopped before processing all messages") + } else { + err = teardownErr + } } }() - select { - case <-ctx.Done(): - return ctx.Err() - case <-n.stop: - // destination actually stopped without ever receiving a message, we can - // just return here - return nil - case <-n.start: - // received first message for ack, destination is open now, we can - // safely start listening to acks - n.logger.Trace(ctx).Msg("start running acker node") + trigger, cleanup, err := n.base.Trigger(ctx, n.logger, errChan) + if err != nil { + close(errChan) // need to close errChan to not block deferred function + return err } + // start worker that will fetch acks from the connector and forward them to + // internal messages + go n.worker(connectorCtx, signalChan, errChan) + + defer cleanup() for { - pos, err := n.Destination.Ack(ctx) - if pos == nil { - // empty position is returned only if an actual error happened - if cerrors.Is(err, plugin.ErrStreamNotOpen) { - // this means the plugin stopped, gracefully shut down - n.logger.Debug(ctx).Msg("ack stream closed") - return nil - } + msg, err := trigger() + if err != nil || msg == nil { return err } - msg, ok := n.cache.LoadAndDelete(pos) - if !ok { - n.logger.Error(ctx). - Str(log.RecordPositionField, pos.String()). - Msg("received unexpected ack (could be an internal bug or a badly written connector), ignoring the ack and continuing, please report the issue to the Conduit team") - continue - } - - // TODO make sure acks are called in the right order or this will block - // forever. Right now we rely on connectors sending acks back in the - // correct order and this should generally be true, but a badly written - // connector could provoke a deadlock, we could prevent that. - err = n.handleAck(msg, err) - if err != nil { - return err + n.m.Lock() + n.queue.PushBack(msg) + n.m.Unlock() + select { + case signalChan <- struct{}{}: + // triggered the start of listening to acks in worker goroutine + default: + // worker goroutine is already busy, it will pick up the message + // because it is already stored in the queue } } } -// teardown will nack all messages still in the cache and return an error in -// case there were still unprocessed messages in the cache. 
-func (n *DestinationAckerNode) teardown(reason error) error {
-	var nacked int
-	var err error
-	n.cache.Range(func(pos record.Position, msg *Message) bool {
-		err = multierror.Append(err, msg.Nack(reason))
-		nacked++
-		return true
-	})
-	if err != nil {
-		return cerrors.Errorf("nacked %d messages when stopping destination acker node, some nacks failed: %w", nacked, err)
-	}
-	if nacked > 0 {
-		return cerrors.Errorf("nacked %d messages when stopping destination acker node", nacked)
+func (n *DestinationAckerNode) worker(
+	ctx context.Context,
+	signalChan <-chan struct{},
+	errChan chan<- error,
+) {
+	defer close(errChan)
+	for range signalChan {
+		// signal is received when a new message is in the queue
+		// let's start fetching acks for messages in the queue
+		for {
+			// check if there are more messages waiting in the queue
+			n.m.Lock()
+			if n.queue.Len() == 0 {
+				n.m.Unlock()
+				break
+			}
+			msg := n.queue.PopFront()
+			n.m.Unlock()
+
+			pos, err := n.Destination.Ack(ctx)
+			if pos == nil {
+				// empty position is returned only if an actual error happened
+				errChan <- cerrors.Errorf("failed to receive ack: %w", err)
+				return
+			}
+			if !bytes.Equal(msg.Record.Position, pos) {
+				errChan <- cerrors.Errorf("received unexpected ack, expected position %q but got %q", msg.Record.Position, pos)
+				return
+			}
+
+			err = n.handleAck(msg, err)
+			if err != nil {
+				errChan <- err
+				return
+			}
+		}
 	}
-	return nil
 }
 
 // handleAck either acks or nacks the message, depending on the supplied error.
@@ -164,126 +165,34 @@ func (n *DestinationAckerNode) handleAck(msg *Message, err error) error {
 	return nil
 }
 
-// ExpectAck makes the handler aware of the message and signals to it that an
-// ack for this message might be received at some point.
-func (n *DestinationAckerNode) ExpectAck(msg *Message) error {
-	// happens only once to signal Run that the destination is ready to be used.
-	n.startOnce.Do(func() {
-		n.init()
-		close(n.start)
-	})
+// teardown will nack all messages still in the queue and return an error in
+// case there were still unprocessed messages in the queue.
+func (n *DestinationAckerNode) teardown(reason error) error {
+	n.m.Lock()
+	defer n.m.Unlock()
 
-	_, loaded := n.cache.LoadOrStore(msg.Record.Position, msg)
-	if loaded {
-		// we already have a message with the same position in the cache
-		n.logger.Error(msg.Ctx).Msg("encountered two records with the same " +
-			"position and can't differentiate them (could be that you are using " +
-			"a pipeline with two same source connectors and they both produced " +
-			"a record with the same position at the same time, could also be a " +
-			"badly written source connector that doesn't assign unique positions " +
-			"to records)")
-		return cerrors.Errorf("encountered two records with the same position (%q)",
-			msg.Record.Position.String())
+	var nacked int
+	var err error
+	for n.queue.Len() > 0 {
+		msg := n.queue.PopFront()
+		err = multierror.Append(err, msg.Nack(reason))
+		nacked++
+	}
+	if err != nil {
+		return cerrors.Errorf("nacked %d messages when stopping destination acker node, some nacks failed: %w", nacked, err)
+	}
+	if nacked > 0 {
+		return cerrors.Errorf("nacked %d messages when stopping destination acker node", nacked)
 	}
 	return nil
 }
 
-// Forget signals the handler that an ack for this message won't be received,
-// and it should remove it from its cache.
-func (n *DestinationAckerNode) Forget(msg *Message) { - n.cache.LoadAndDelete(msg.Record.Position) -} - -// Wait can be used to wait for the count of outstanding acks to drop to 0 or -// the context gets canceled. Wait is expected to be the last function called on -// DestinationAckerNode, after Wait returns DestinationAckerNode will soon stop -// running. -func (n *DestinationAckerNode) Wait(ctx context.Context) { - // happens only once to signal that the destination is stopping - n.stopOnce.Do(func() { - n.init() - close(n.stop) - }) - - t := time.NewTimer(time.Second) - defer t.Stop() - for { - cacheSize := n.cache.Len() - if cacheSize == 0 { - return - } - n.logger.Debug(ctx). - Int("remaining", cacheSize). - Msg("waiting for destination acker node to process remaining acks") - select { - case <-ctx.Done(): - n.logger.Warn(ctx). - Int("remaining", cacheSize). - Msg("stopped waiting for destination acker node even though some acks may be remaining") - return - case <-t.C: - } - } +// Sub will subscribe this node to an incoming channel. +func (n *DestinationAckerNode) Sub(in <-chan *Message) { + n.base.Sub(in) } // SetLogger sets the logger. func (n *DestinationAckerNode) SetLogger(logger log.CtxLogger) { n.logger = logger } - -// positionMessageMap is like a Go map[record.Position]*Message but is safe for -// concurrent use by multiple goroutines. See documentation for sync.Map for -// more information (it's being used under the hood). -type positionMessageMap struct { - m sync.Map - length uint32 -} - -// LoadAndDelete deletes the value for a key, returning the previous value if any. -// The loaded result reports whether the key was present. -func (m *positionMessageMap) LoadAndDelete(pos record.Position) (msg *Message, loaded bool) { - val, loaded := m.m.LoadAndDelete(m.key(pos)) - if !loaded { - return nil, false - } - atomic.AddUint32(&m.length, ^uint32(0)) // decrement - return val.(*Message), loaded -} - -// LoadOrStore returns the existing value for the key if present. -// Otherwise, it stores and returns the given value. -// The loaded result is true if the value was loaded, false if stored. -func (m *positionMessageMap) LoadOrStore(pos record.Position, msg *Message) (actual *Message, loaded bool) { - val, loaded := m.m.LoadOrStore(m.key(pos), msg) - if !loaded { - atomic.AddUint32(&m.length, 1) // increment - } - return val.(*Message), loaded -} - -// Range calls f sequentially for each key and value present in the map. -// If f returns false, range stops the iteration. -// -// Range does not necessarily correspond to any consistent snapshot of the Map's -// contents: no key will be visited more than once, but if the value for any key -// is stored or deleted concurrently, Range may reflect any mapping for that key -// from any point during the Range call. -// -// Range may be O(N) with the number of elements in the map even if f returns -// false after a constant number of calls. -func (m *positionMessageMap) Range(f func(pos record.Position, msg *Message) bool) { - m.m.Range(func(key, value interface{}) bool { - return f(record.Position(key.(string)), value.(*Message)) - }) -} - -// Len returns the number of elements in the map. -func (m *positionMessageMap) Len() int { - return int(atomic.LoadUint32(&m.length)) -} - -// key takes a position and converts it into a hashable object that can be used -// as a key in a map. 
-func (m *positionMessageMap) key(pos record.Position) interface{} { - return string(pos) -} diff --git a/pkg/pipeline/stream/destination_acker_test.go b/pkg/pipeline/stream/destination_acker_test.go deleted file mode 100644 index 66dfe759e..000000000 --- a/pkg/pipeline/stream/destination_acker_test.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright © 2022 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stream - -import ( - "context" - "testing" - "time" - - "github.com/conduitio/conduit/pkg/connector/mock" - "github.com/conduitio/conduit/pkg/plugin" - "github.com/conduitio/conduit/pkg/record" - "github.com/golang/mock/gomock" - "github.com/matryer/is" -) - -func TestAckerNode_Run_StopAfterWait(t *testing.T) { - is := is.New(t) - ctx := context.Background() - ctrl := gomock.NewController(t) - dest := mock.NewDestination(ctrl) - - node := &DestinationAckerNode{ - Name: "acker-node", - Destination: dest, - } - - nodeDone := make(chan struct{}) - go func() { - defer close(nodeDone) - err := node.Run(ctx) - is.NoErr(err) - }() - - // note that there should be no calls to the destination at all if we didn't - // receive any ExpectedAck call - - // give the test 1 second to finish - waitCtx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - node.Wait(waitCtx) - - select { - case <-waitCtx.Done(): - is.Fail() // expected node to stop running - case <-nodeDone: - // all good - } -} - -func TestAckerNode_Run_StopAfterExpectAck(t *testing.T) { - is := is.New(t) - ctx := context.Background() - ctrl := gomock.NewController(t) - dest := mock.NewDestination(ctrl) - - node := &DestinationAckerNode{ - Name: "acker-node", - Destination: dest, - } - - nodeDone := make(chan struct{}) - go func() { - defer close(nodeDone) - err := node.Run(ctx) - is.NoErr(err) - }() - - // up to this point there should have been no calls to the destination - // only after the call to ExpectAck should the node try to fetch any acks - msg := &Message{ - Record: record.Record{Position: record.Position("test-position")}, - } - // first return position - expectAck := make(chan struct{}) - c1 := dest.EXPECT().Ack(gomock.Any()). - DoAndReturn(func(ctx context.Context) (record.Position, error) { - // wait until ExpectAck is called - <-expectAck - return msg.Record.Position, nil - }) - // second return closed stream - dest.EXPECT().Ack(gomock.Any()). 
- Return(nil, plugin.ErrStreamNotOpen).After(c1) - - err := node.ExpectAck(msg) - close(expectAck) // signal to mock that ExpectAck returned - is.NoErr(err) - - // give the test 1 second to finish - waitCtx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - select { - case <-waitCtx.Done(): - is.Fail() // expected node to stop running - case <-nodeDone: - // all good - } -} diff --git a/pkg/pipeline/stream/message.go b/pkg/pipeline/stream/message.go index 1ab6adb32..822584518 100644 --- a/pkg/pipeline/stream/message.go +++ b/pkg/pipeline/stream/message.go @@ -261,3 +261,23 @@ func (m *Message) Status() MessageStatus { return MessageStatusOpen } } + +// OpenMessagesTracker allows you to track messages until they reach the end of +// the pipeline. +type OpenMessagesTracker sync.WaitGroup + +// Add will increase the counter in the wait group and register a status handler +// that will decrease the counter when the message is acked or nacked. +func (t *OpenMessagesTracker) Add(msg *Message) { + (*sync.WaitGroup)(t).Add(1) + msg.RegisterStatusHandler( + func(msg *Message, change StatusChange) error { + (*sync.WaitGroup)(t).Done() + return nil + }, + ) +} + +func (t *OpenMessagesTracker) Wait() { + (*sync.WaitGroup)(t).Wait() +} diff --git a/pkg/pipeline/stream/metrics.go b/pkg/pipeline/stream/metrics.go index 0fce43dd7..9395cee0d 100644 --- a/pkg/pipeline/stream/metrics.go +++ b/pkg/pipeline/stream/metrics.go @@ -34,7 +34,7 @@ func (n *MetricsNode) ID() string { } func (n *MetricsNode) Run(ctx context.Context) error { - trigger, cleanup, err := n.base.Trigger(ctx, n.logger) + trigger, cleanup, err := n.base.Trigger(ctx, n.logger, nil) if err != nil { return err } diff --git a/pkg/pipeline/stream/processor.go b/pkg/pipeline/stream/processor.go index 74a68a3d6..3526575b1 100644 --- a/pkg/pipeline/stream/processor.go +++ b/pkg/pipeline/stream/processor.go @@ -38,7 +38,7 @@ func (n *ProcessorNode) ID() string { } func (n *ProcessorNode) Run(ctx context.Context) error { - trigger, cleanup, err := n.base.Trigger(ctx, n.logger) + trigger, cleanup, err := n.base.Trigger(ctx, n.logger, nil) if err != nil { return err } diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go index 08b3f6f17..38847e899 100644 --- a/pkg/pipeline/stream/source.go +++ b/pkg/pipeline/stream/source.go @@ -17,7 +17,6 @@ package stream import ( "bytes" "context" - "sync" "time" "github.com/conduitio/conduit/pkg/connector" @@ -59,11 +58,12 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { return cerrors.Errorf("could not open source connector: %w", err) } - var wgOpenMessages sync.WaitGroup + // openMsgTracker tracks open messages until they are acked or nacked + var openMsgTracker OpenMessagesTracker defer func() { // wait for open messages before tearing down connector n.logger.Trace(ctx).Msg("waiting for open messages to be processed") - wgOpenMessages.Wait() + openMsgTracker.Wait() n.logger.Trace(ctx).Msg("all messages processed, tearing down source") tdErr := n.Source.Teardown(connectorCtx) if tdErr != nil { @@ -82,7 +82,6 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { n.Source.Errors(), func(ctx context.Context) (*Message, error) { n.logger.Trace(ctx).Msg("reading record from source connector") - r, err := n.Source.Read(ctx) if err != nil { return nil, cerrors.Errorf("error reading from source: %w", err) @@ -126,24 +125,22 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { continue } - // register another open message - wgOpenMessages.Add(1) + // 
track message until it reaches an end state
+		openMsgTracker.Add(msg)
+
 		msg.RegisterStatusHandler(
 			func(msg *Message, change StatusChange) error {
-				// this is the last handler to be executed, once this handler is
-				// reached we know either the message was either acked or nacked
-				defer n.PipelineTimer.Update(time.Since(msg.Record.ReadAt))
-				defer wgOpenMessages.Done()
+				n.PipelineTimer.Update(time.Since(msg.Record.ReadAt))
 				return nil
 			},
 		)
+
+		lastPosition = msg.Record.Position
 		err = n.base.Send(ctx, n.logger, msg)
 		if err != nil {
 			return msg.Nack(err)
 		}
-		lastPosition = msg.Record.Position
 		if bytes.Equal(stopPosition, lastPosition) {
 			// it's the last record that we are supposed to process, stop here
 			return n.stopReason
@@ -152,6 +149,7 @@
 }
 
 func (n *SourceNode) Stop(ctx context.Context, reason error) error {
+	n.stopReason = reason
 	// InjectMessage will inject a message into the stream of messages being
 	// processed by SourceNode to let it know when it should stop processing new
 	// messages.
diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go
index 8614ed8e0..7c276d7f5 100644
--- a/pkg/pipeline/stream/source_acker.go
+++ b/pkg/pipeline/stream/source_acker.go
@@ -46,7 +46,7 @@ func (n *SourceAckerNode) ID() string {
 }
 
 func (n *SourceAckerNode) Run(ctx context.Context) error {
-	trigger, cleanup, err := n.base.Trigger(ctx, n.logger)
+	trigger, cleanup, err := n.base.Trigger(ctx, n.logger, nil)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/pipeline/stream/stream_test.go b/pkg/pipeline/stream/stream_test.go
index 8cd20749a..b67aed0d2 100644
--- a/pkg/pipeline/stream/stream_test.go
+++ b/pkg/pipeline/stream/stream_test.go
@@ -60,7 +60,6 @@ func Example_simpleStream() {
 		Name:        "printer-acker",
 		Destination: node3.Destination,
 	}
-	node3.AckerNode = node4
 
 	stream.SetLogger(node1, logger)
 	stream.SetLogger(node2, logger)
@@ -70,6 +69,7 @@ func Example_simpleStream() {
 	// put everything together
 	node2.Sub(node1.Pub())
 	node3.Sub(node2.Pub())
+	node4.Sub(node3.Pub())
 
 	var wg sync.WaitGroup
 	wg.Add(4)
@@ -111,6 +111,7 @@ func Example_simpleStream() {
 	// INF stopping source connector component=SourceNode node_id=generator
 	// DBG incoming messages channel closed component=SourceAckerNode node_id=generator-acker
 	// DBG incoming messages channel closed component=DestinationNode node_id=printer
+	// DBG incoming messages channel closed component=DestinationAckerNode node_id=printer-acker
 	// INF finished successfully
 }
 
@@ -153,23 +154,25 @@ func Example_complexStream() {
 		Destination:    printerDestination(ctrl, logger, "printer1"),
 		ConnectorTimer: noop.Timer{},
 	}
-	node9 := &stream.DestinationNode{
+	node9 := &stream.DestinationAckerNode{
+		Name:        "printer1-acker",
+		Destination: node8.Destination,
+	}
+	node10 := &stream.DestinationNode{
 		Name:           "printer2",
 		Destination:    printerDestination(ctrl, logger, "printer2"),
 		ConnectorTimer: noop.Timer{},
 	}
-	node10 := &stream.DestinationAckerNode{
-		Name:        "printer1-acker",
-		Destination: node8.Destination,
-	}
-	node8.AckerNode = node10
 	node11 := &stream.DestinationAckerNode{
 		Name:        "printer2-acker",
-		Destination: node9.Destination,
+		Destination: node10.Destination,
 	}
-	node9.AckerNode = node11
 
 	// put everything together
+	// this is the pipeline we are building
+	// [1] -> [2] -\                     /-> [8]  -> [9]
+	//              |- [5] -> [6] -> [7] -|
+	// [3] -> [4] -/                     \-> [10] -> [11]
 	node2.Sub(node1.Pub())
 	node4.Sub(node3.Pub())
 
@@ -180,7 +183,10 @@
 	node7.Sub(node6.Pub())
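+	// node7 fans out to both destination nodes (node8 and node10), and each
+	// destination then publishes its written messages to its own acker node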
node8.Sub(node7.Pub()) - node9.Sub(node7.Pub()) + node10.Sub(node7.Pub()) + + node9.Sub(node8.Pub()) + node11.Sub(node10.Pub()) // run nodes nodes := []stream.Node{node1, node2, node3, node4, node5, node6, node7, node8, node9, node10, node11} @@ -274,8 +280,10 @@ func Example_complexStream() { // DBG incoming messages channel closed component=SourceAckerNode node_id=generator1-acker // DBG incoming messages channel closed component=SourceAckerNode node_id=generator2-acker // DBG incoming messages channel closed component=ProcessorNode node_id=counter - // DBG incoming messages channel closed component=DestinationNode node_id=printer2 // DBG incoming messages channel closed component=DestinationNode node_id=printer1 + // DBG incoming messages channel closed component=DestinationNode node_id=printer2 + // DBG incoming messages channel closed component=DestinationAckerNode node_id=printer1-acker + // DBG incoming messages channel closed component=DestinationAckerNode node_id=printer2-acker // INF counter node counted 20 messages // INF finished successfully } @@ -333,7 +341,7 @@ func generatorSource(ctrl *gomock.Controller, logger log.CtxLogger, nodeID strin func printerDestination(ctrl *gomock.Controller, logger log.CtxLogger, nodeID string) connector.Destination { var lastPosition record.Position - rchan := make(chan record.Record) + rchan := make(chan record.Record, 1) destination := connmock.NewDestination(ctrl) destination.EXPECT().Open(gomock.Any()).Return(nil).Times(1) destination.EXPECT().Write(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, r record.Record) error { diff --git a/pkg/plugin/acceptance_testing.go b/pkg/plugin/acceptance_testing.go index a995dc079..1ca0232e5 100644 --- a/pkg/plugin/acceptance_testing.go +++ b/pkg/plugin/acceptance_testing.go @@ -18,7 +18,6 @@ package plugin import ( "context" "fmt" - "io" "reflect" "runtime" "strings" @@ -74,7 +73,6 @@ func AcceptanceTestV1(t *testing.T, tdf testDispenserFunc) { run(t, tdf, testDestination_Ack_WithoutStart) run(t, tdf, testDestination_Run_Fail) run(t, tdf, testDestination_Teardown_Success) - run(t, tdf, testDestination_Teardown_CloseSend) } func run(t *testing.T, tdf testDispenserFunc, test func(*testing.T, testDispenserFunc)) { @@ -411,11 +409,7 @@ func testSource_Read_CancelContext(t *testing.T, tdf testDispenserFunc) { }) _, err = source.Read(ctx) - is.True(err != nil) - // TODO see if we can change this error into context.Canceled, right now we - // follow the default gRPC behavior - is.True(!cerrors.Is(err, context.Canceled)) - is.True(!cerrors.Is(err, ErrStreamNotOpen)) + is.True(cerrors.Is(err, context.Canceled)) close(stopRunCh) // stop run channel } @@ -916,50 +910,3 @@ func testDestination_Teardown_Success(t *testing.T, tdf testDispenserFunc) { t.Fatal("should've received call to destination.Run") } } - -func testDestination_Teardown_CloseSend(t *testing.T, tdf testDispenserFunc) { - is := is.New(t) - - ctx := context.Background() - dispenser, _, _, mockDestination := tdf(t) - - closeCh := make(chan struct{}) - mockDestination.EXPECT(). - Start(gomock.Any(), cpluginv1.DestinationStartRequest{}). - Return(cpluginv1.DestinationStartResponse{}, nil) - mockDestination.EXPECT(). - Stop(gomock.Any(), cpluginv1.DestinationStopRequest{ - LastPosition: []byte("foo"), - }). - Return(cpluginv1.DestinationStopResponse{}, nil) - mockDestination.EXPECT(). - Run(gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, stream cpluginv1.DestinationRunStream) error { - _, recvErr := stream.Recv() - is.Equal(recvErr, io.EOF) - close(closeCh) - return recvErr - }) - mockDestination.EXPECT(). - Teardown(gomock.Any(), cpluginv1.DestinationTeardownRequest{}). - Return(cpluginv1.DestinationTeardownResponse{}, nil) - - destination, err := dispenser.DispenseDestination() - is.NoErr(err) - - err = destination.Start(ctx) - is.NoErr(err) - err = destination.Stop(ctx, record.Position("foo")) - is.NoErr(err) - - err = destination.Teardown(ctx) - is.NoErr(err) - - select { - case <-closeCh: - // all good, outgoing stream was closed - case <-time.After(time.Second): - is.Fail() // expected outgoing stream to be closed - } - -} diff --git a/pkg/plugin/builtin/v1/destination.go b/pkg/plugin/builtin/v1/destination.go index 09823bc81..d7acc5e64 100644 --- a/pkg/plugin/builtin/v1/destination.go +++ b/pkg/plugin/builtin/v1/destination.go @@ -91,7 +91,7 @@ func (s *destinationPluginAdapter) Start(ctx context.Context) error { s.logger.Trace(ctx).Msg("calling Run") err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.DestinationRunStream(s.stream)) if err != nil { - if s.stream.stop(cerrors.Errorf("error in run: %w", err)) { + if !s.stream.stop(err) { s.logger.Err(ctx, err).Msg("stream already stopped") } } else { @@ -157,11 +157,6 @@ func (s *destinationPluginAdapter) Stop(ctx context.Context, lastPosition record } func (s *destinationPluginAdapter) Teardown(ctx context.Context) error { - if s.stream != nil { - // stop stream if it's open - _ = s.stream.stop(nil) - } - s.logger.Trace(ctx).Msg("calling Teardown") resp, err := runSandbox(s.impl.Teardown, s.withLogger(ctx), toplugin.DestinationTeardownRequest()) if err != nil { diff --git a/pkg/plugin/builtin/v1/source.go b/pkg/plugin/builtin/v1/source.go index e290fb945..0e722115b 100644 --- a/pkg/plugin/builtin/v1/source.go +++ b/pkg/plugin/builtin/v1/source.go @@ -96,7 +96,7 @@ func (s *sourcePluginAdapter) Start(ctx context.Context, p record.Position) erro s.logger.Trace(ctx).Msg("calling Run") err := runSandboxNoResp(s.impl.Run, s.withLogger(ctx), cpluginv1.SourceRunStream(s.stream)) if err != nil { - if s.stream.stop(cerrors.Errorf("error in run: %w", err)) { + if !s.stream.stop(err) { s.logger.Err(ctx, err).Msg("stream already stopped") } } else { @@ -165,10 +165,6 @@ func (s *sourcePluginAdapter) Stop(ctx context.Context) (record.Position, error) } func (s *sourcePluginAdapter) Teardown(ctx context.Context) error { - if s.stream != nil { - s.stream.stop(nil) - } - s.logger.Trace(ctx).Msg("calling Teardown") resp, err := runSandbox(s.impl.Teardown, s.withLogger(ctx), toplugin.SourceTeardownRequest()) if err != nil { diff --git a/pkg/plugin/builtin/v1/stream.go b/pkg/plugin/builtin/v1/stream.go index 6b287ac93..541792eb7 100644 --- a/pkg/plugin/builtin/v1/stream.go +++ b/pkg/plugin/builtin/v1/stream.go @@ -19,7 +19,6 @@ import ( "io" "sync" - "github.com/conduitio/conduit/pkg/foundation/cerrors" "github.com/conduitio/conduit/pkg/plugin" ) @@ -61,7 +60,7 @@ func (s *stream[REQ, RES]) Recv() (REQ, error) { func (s *stream[REQ, RES]) recvInternal() (RES, error) { select { case <-s.ctx.Done(): - return s.emptyRes(), cerrors.New(s.ctx.Err().Error()) + return s.emptyRes(), s.ctx.Err() case <-s.stopChan: return s.emptyRes(), s.reason case resp := <-s.respChan: @@ -72,7 +71,7 @@ func (s *stream[REQ, RES]) recvInternal() (RES, error) { func (s *stream[REQ, RES]) sendInternal(req REQ) error { select { case <-s.ctx.Done(): - return 
cerrors.New(s.ctx.Err().Error()) + return s.ctx.Err() case <-s.stopChan: return plugin.ErrStreamNotOpen case s.reqChan <- req: diff --git a/pkg/plugin/standalone/v1/client.go b/pkg/plugin/standalone/v1/client.go index cb007d9aa..ef17e2c87 100644 --- a/pkg/plugin/standalone/v1/client.go +++ b/pkg/plugin/standalone/v1/client.go @@ -15,6 +15,7 @@ package standalonev1 import ( + "context" "fmt" "net" "os/exec" @@ -138,11 +139,23 @@ func getFreePort() int { return l.Addr().(*net.TCPAddr).Port } -// unwrapGRPCError removes the gRPC wrapper from the error. +// knownErrors contains known error messages that are mapped to internal error +// types. gRPC does not retain error types, so we have to resort to relying on +// the error message itself. +var knownErrors = map[string]error{ + "context canceled": context.Canceled, + "context deadline exceeded": context.DeadlineExceeded, +} + +// unwrapGRPCError removes the gRPC wrapper from the error and returns a known +// error if possible, otherwise creates an internal error. func unwrapGRPCError(err error) error { st, ok := status.FromError(err) if !ok { return err } + if knownErr, ok := knownErrors[st.Message()]; ok { + return knownErr + } return cerrors.New(st.Message()) } diff --git a/pkg/plugin/standalone/v1/destination.go b/pkg/plugin/standalone/v1/destination.go index 81d37cfd0..437236f35 100644 --- a/pkg/plugin/standalone/v1/destination.go +++ b/pkg/plugin/standalone/v1/destination.go @@ -19,7 +19,6 @@ import ( "io" "github.com/conduitio/conduit/pkg/foundation/cerrors" - "github.com/conduitio/conduit/pkg/foundation/multierror" "github.com/conduitio/conduit/pkg/plugin" "github.com/conduitio/conduit/pkg/plugin/standalone/v1/internal/fromproto" "github.com/conduitio/conduit/pkg/plugin/standalone/v1/internal/toproto" @@ -139,20 +138,12 @@ func (s *destinationPluginClient) Stop(ctx context.Context, lastPosition record. 
} func (s *destinationPluginClient) Teardown(ctx context.Context) error { - var errOut error - if s.stream != nil { - err := s.stream.CloseSend() - if err != nil { - errOut = multierror.Append(errOut, unwrapGRPCError(err)) - } - } - protoReq := toproto.DestinationTeardownRequest() protoResp, err := s.grpcClient.Teardown(ctx, protoReq) if err != nil { - errOut = multierror.Append(errOut, unwrapGRPCError(err)) + return unwrapGRPCError(err) } _ = protoResp // response is empty - return errOut + return nil } diff --git a/pkg/plugin/standalone/v1/source.go b/pkg/plugin/standalone/v1/source.go index 8df343cdb..0eff20724 100644 --- a/pkg/plugin/standalone/v1/source.go +++ b/pkg/plugin/standalone/v1/source.go @@ -19,7 +19,6 @@ import ( "io" "github.com/conduitio/conduit/pkg/foundation/cerrors" - "github.com/conduitio/conduit/pkg/foundation/multierror" "github.com/conduitio/conduit/pkg/plugin" "github.com/conduitio/conduit/pkg/plugin/standalone/v1/internal/fromproto" "github.com/conduitio/conduit/pkg/plugin/standalone/v1/internal/toproto" @@ -142,21 +141,12 @@ func (s *sourcePluginClient) Stop(ctx context.Context) (record.Position, error) } func (s *sourcePluginClient) Teardown(ctx context.Context) error { - var errOut error - - if s.stream != nil { - err := s.stream.CloseSend() - if err != nil { - errOut = multierror.Append(errOut, unwrapGRPCError(err)) - } - } - protoReq := toproto.SourceTeardownRequest() protoResp, err := s.grpcClient.Teardown(ctx, protoReq) if err != nil { - errOut = multierror.Append(errOut, unwrapGRPCError(err)) + return unwrapGRPCError(err) } _ = protoResp // response is empty - return errOut + return nil } From 8c534b14679f2f1bb5a726e9fb7ca7249ee4f3bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Mon, 4 Jul 2022 17:43:14 +0200 Subject: [PATCH 31/46] destination acker tests --- pkg/pipeline/stream/destination_acker.go | 14 +- pkg/pipeline/stream/destination_acker_test.go | 356 ++++++++++++++++++ 2 files changed, 368 insertions(+), 2 deletions(-) create mode 100644 pkg/pipeline/stream/destination_acker_test.go diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index 0db2d835c..1ae746aca 100644 --- a/pkg/pipeline/stream/destination_acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -85,6 +85,7 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { // start worker that will fetch acks from the connector and forward them to // internal messages go n.worker(connectorCtx, signalChan, errChan) + signalChan <- struct{}{} // wait for worker to start before fetching first message defer cleanup() for { @@ -111,6 +112,15 @@ func (n *DestinationAckerNode) worker( signalChan <-chan struct{}, errChan chan<- error, ) { + handleError := func(msg *Message, err error) { + // push message back to the front of the queue and return error + n.m.Lock() + n.queue.PushFront(msg) + n.m.Unlock() + + errChan <- err + } + defer close(errChan) for range signalChan { // signal is received when a new message is in the queue @@ -128,11 +138,11 @@ func (n *DestinationAckerNode) worker( pos, err := n.Destination.Ack(ctx) if pos == nil { // empty position is returned only if an actual error happened - errChan <- cerrors.Errorf("failed to receive ack: %w", err) + handleError(msg, cerrors.Errorf("failed to receive ack: %w", err)) return } if !bytes.Equal(msg.Record.Position, pos) { - errChan <- cerrors.Errorf("received unexpected ack, expected position %q but got %q", msg.Record.Position, pos) + handleError(msg, 
cerrors.Errorf("received unexpected ack, expected position %q but got %q", msg.Record.Position, pos)) return } diff --git a/pkg/pipeline/stream/destination_acker_test.go b/pkg/pipeline/stream/destination_acker_test.go new file mode 100644 index 000000000..7b9463472 --- /dev/null +++ b/pkg/pipeline/stream/destination_acker_test.go @@ -0,0 +1,356 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stream + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/conduitio/conduit/pkg/connector/mock" + "github.com/conduitio/conduit/pkg/foundation/cerrors" + "github.com/conduitio/conduit/pkg/record" + "github.com/golang/mock/gomock" + "github.com/matryer/is" +) + +func TestDestinationAckerNode_Cache(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + dest := mock.NewDestination(ctrl) + + node := &DestinationAckerNode{ + Name: "destination-acker-node", + Destination: dest, + } + + in := make(chan *Message) + node.Sub(in) + + nodeDone := make(chan struct{}) + go func() { + defer close(nodeDone) + err := node.Run(ctx) + is.NoErr(err) + }() + + const count = 1000 + currentPosition := 0 + dest.EXPECT().Ack(gomock.Any()). + DoAndReturn(func(ctx context.Context) (record.Position, error) { + pos := fmt.Sprintf("test-position-%d", currentPosition) + currentPosition++ + return record.Position(pos), nil + }).Times(count) + + var ackHandlerWg sync.WaitGroup + ackHandlerWg.Add(count) + for i := 0; i < count; i++ { + pos := fmt.Sprintf("test-position-%d", i) + msg := &Message{ + Record: record.Record{Position: record.Position(pos)}, + } + msg.RegisterAckHandler(func(msg *Message) error { + ackHandlerWg.Done() + return nil + }) + in <- msg + } + + // note that there should be no calls to the destination at all if the node + // didn't receive any messages + close(in) + + select { + case <-time.After(time.Second): + is.Fail() // expected node to stop running + case <-nodeDone: + // all good + } + + ackHandlerWg.Wait() // all ack handler should be called by now +} + +func TestDestinationAckerNode_ForwardAck(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + dest := mock.NewDestination(ctrl) + + node := &DestinationAckerNode{ + Name: "destination-acker-node", + Destination: dest, + } + + in := make(chan *Message) + node.Sub(in) + + nodeDone := make(chan struct{}) + go func() { + defer close(nodeDone) + err := node.Run(ctx) + is.NoErr(err) + }() + + // up to this point there should have been no calls to the destination + // only after a received message should the node try to fetch the ack + msg := &Message{ + Record: record.Record{Position: record.Position("test-position")}, + } + dest.EXPECT().Ack(gomock.Any()). 
+ DoAndReturn(func(ctx context.Context) (record.Position, error) { + return msg.Record.Position, nil + }) + ackHandlerDone := make(chan struct{}) + msg.RegisterAckHandler(func(got *Message) error { + defer close(ackHandlerDone) + is.Equal(msg, got) + return nil + }) + in <- msg // send message to incoming channel + + select { + case <-time.After(time.Second): + is.Fail() // expected ack handler to be called + case <-ackHandlerDone: + // all good + } + + // note that there should be no calls to the destination at all if the node + // didn't receive any messages + close(in) + + select { + case <-time.After(time.Second): + is.Fail() // expected node to stop running + case <-nodeDone: + // all good + } +} + +func TestDestinationAckerNode_ForwardNack(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + dest := mock.NewDestination(ctrl) + + node := &DestinationAckerNode{ + Name: "destination-acker-node", + Destination: dest, + } + + in := make(chan *Message) + node.Sub(in) + + nodeDone := make(chan struct{}) + go func() { + defer close(nodeDone) + err := node.Run(ctx) + is.NoErr(err) + }() + + // up to this point there should have been no calls to the destination + // only after a received message should the node try to fetch the ack + msg := &Message{ + Record: record.Record{Position: record.Position("test-position")}, + } + wantErr := cerrors.New("test error") + dest.EXPECT().Ack(gomock.Any()). + DoAndReturn(func(ctx context.Context) (record.Position, error) { + return msg.Record.Position, wantErr // destination returns nack + }) + nackHandlerDone := make(chan struct{}) + msg.RegisterNackHandler(func(got *Message, reason error) error { + defer close(nackHandlerDone) + is.Equal(msg, got) + is.Equal(wantErr, reason) + return nil + }) + in <- msg // send message to incoming channel + + select { + case <-time.After(time.Second): + is.Fail() // expected nack handler to be called + case <-nackHandlerDone: + // all good + } + + // note that there should be no calls to the destination at all if the node + // didn't receive any messages + close(in) + + select { + case <-time.After(time.Second): + is.Fail() // expected node to stop running + case <-nodeDone: + // all good + } +} + +func TestDestinationAckerNode_UnexpectedPosition(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + dest := mock.NewDestination(ctrl) + + node := &DestinationAckerNode{ + Name: "destination-acker-node", + Destination: dest, + } + + in := make(chan *Message) + node.Sub(in) + + nodeDone := make(chan struct{}) + go func() { + defer close(nodeDone) + err := node.Run(ctx) + is.True(err != nil) // expected node to fail + }() + + msg := &Message{ + Record: record.Record{Position: record.Position("test-position")}, + } + dest.EXPECT().Ack(gomock.Any()). 
+ DoAndReturn(func(ctx context.Context) (record.Position, error) { + return record.Position("something-unexpected"), nil // destination returns unexpected position + }) + + // nack should be still called when node exits + nackHandlerDone := make(chan struct{}) + msg.RegisterNackHandler(func(got *Message, reason error) error { + defer close(nackHandlerDone) + is.True(reason != nil) + return nil + }) + in <- msg // send message to incoming channel + + select { + case <-time.After(time.Second): + is.Fail() // expected nack handler to be called + case <-nackHandlerDone: + // all good + } + + // note that we don't close the in channel this time and still expect the + // node to stop running + + select { + case <-time.After(time.Second): + is.Fail() // expected node to stop running + case <-nodeDone: + // all good + } +} + +func TestDestinationAckerNode_DestinationAckError(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + dest := mock.NewDestination(ctrl) + + node := &DestinationAckerNode{ + Name: "destination-acker-node", + Destination: dest, + } + + in := make(chan *Message) + node.Sub(in) + + wantErr := cerrors.New("test error") + nodeDone := make(chan struct{}) + go func() { + defer close(nodeDone) + err := node.Run(ctx) + is.True(cerrors.Is(err, wantErr)) // expected node to fail with specific error + }() + + dest.EXPECT().Ack(gomock.Any()). + DoAndReturn(func(ctx context.Context) (record.Position, error) { + return nil, wantErr // destination returns unexpected error + }) + + in <- &Message{ + Record: record.Record{Position: record.Position("test-position")}, + } + + // note that we don't close the in channel this time and still expect the + // node to stop running + + select { + case <-time.After(time.Second): + is.Fail() // expected node to stop running + case <-nodeDone: + // all good + } +} + +func TestDestinationAckerNode_MessageAckError(t *testing.T) { + is := is.New(t) + ctx := context.Background() + ctrl := gomock.NewController(t) + dest := mock.NewDestination(ctrl) + + node := &DestinationAckerNode{ + Name: "destination-acker-node", + Destination: dest, + } + + in := make(chan *Message) + node.Sub(in) + + wantErr := cerrors.New("test error") + nodeDone := make(chan struct{}) + go func() { + defer close(nodeDone) + err := node.Run(ctx) + is.True(cerrors.Is(err, wantErr)) // expected node to fail with specific error + }() + + msg := &Message{ + Record: record.Record{Position: record.Position("test-position")}, + } + dest.EXPECT().Ack(gomock.Any()). 
+ DoAndReturn(func(ctx context.Context) (record.Position, error) { + return msg.Record.Position, nil + }) + ackHandlerDone := make(chan struct{}) + msg.RegisterAckHandler(func(*Message) error { + defer close(ackHandlerDone) + return wantErr // ack handler fails + }) + + in <- msg + + select { + case <-time.After(time.Second): + is.Fail() // expected ack handler to be called + case <-ackHandlerDone: + // all good + } + + // note that we don't close the in channel this time and still expect the + // node to stop running + + select { + case <-time.After(time.Second): + is.Fail() // expected node to stop running + case <-nodeDone: + // all good + } +} From f43fe35c5fd08179ae173038beebcaf8f159a18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 7 Jul 2022 12:53:03 +0200 Subject: [PATCH 32/46] use cerrors.New --- pkg/foundation/semaphore/semaphore.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go index 9a267eba4..e68081a40 100644 --- a/pkg/foundation/semaphore/semaphore.go +++ b/pkg/foundation/semaphore/semaphore.go @@ -70,7 +70,7 @@ func (s *Simple) Acquire(t Ticket) error { s.mu.Lock() if s.batch != t.batch { s.mu.Unlock() - return cerrors.Errorf("semaphore: invalid batch") + return cerrors.New("semaphore: invalid batch") } w := s.waiters[t.index] @@ -101,7 +101,7 @@ func (s *Simple) Release(t Ticket) error { defer s.mu.Unlock() if s.batch != t.batch { - return cerrors.Errorf("semaphore: invalid batch") + return cerrors.New("semaphore: invalid batch") } w := s.waiters[t.index] if !w.acquired { From befaf4427cfdd9e662a62e15527e20639f6f066d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 7 Jul 2022 17:46:39 +0200 Subject: [PATCH 33/46] use LogOrReplace --- pkg/pipeline/stream/source_acker.go | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go index c716f306e..5c354b25f 100644 --- a/pkg/pipeline/stream/source_acker.go +++ b/pkg/pipeline/stream/source_acker.go @@ -79,14 +79,9 @@ func (n *SourceAckerNode) registerAckHandler(msg *Message, ticket semaphore.Tick n.fail = true } tmpErr := n.sem.Release(ticket) - if tmpErr != nil { - if err != nil { - // we are already returning an error, log this one instead - n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for ack") - } else { - err = tmpErr - } - } + err = cerrors.LogOrReplace(err, tmpErr, func() { + n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for ack") + }) }() n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for ack") err = n.sem.Acquire(ticket) @@ -113,14 +108,9 @@ func (n *SourceAckerNode) registerNackHandler(msg *Message, ticket semaphore.Tic n.fail = true } tmpErr := n.sem.Release(ticket) - if tmpErr != nil { - if err != nil { - // we are already returning an error, log this one instead - n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for nack") - } else { - err = tmpErr - } - } + err = cerrors.LogOrReplace(err, tmpErr, func() { + n.logger.Err(msg.Ctx, tmpErr).Msg("error releasing semaphore ticket for nack") + }) }() n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for nack") err = n.sem.Acquire(ticket) From 297f553f48ea56cab5869dc7b863a7d6bfc5254f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Thu, 7 Jul 2022 17:55:35 +0200 Subject: [PATCH 34/46] use LogOrReplace --- 
pkg/pipeline/stream/destination.go | 11 +++-------- pkg/pipeline/stream/destination_acker.go | 22 ++++++---------------- pkg/pipeline/stream/source.go | 11 +++-------- 3 files changed, 12 insertions(+), 32 deletions(-) diff --git a/pkg/pipeline/stream/destination.go b/pkg/pipeline/stream/destination.go index 4149170b2..4facc1f92 100644 --- a/pkg/pipeline/stream/destination.go +++ b/pkg/pipeline/stream/destination.go @@ -71,14 +71,9 @@ func (n *DestinationNode) Run(ctx context.Context) (err error) { // teardown will kill the plugin process tdErr := n.Destination.Teardown(connectorCtx) - if tdErr != nil { - if err == nil { - err = tdErr - } else { - // we are already returning an error, just log this error - n.logger.Err(ctx, err).Msg("could not tear down destination connector") - } - } + err = cerrors.LogOrReplace(err, tdErr, func() { + n.logger.Err(ctx, tdErr).Msg("could not tear down destination connector") + }) }() trigger, cleanup, err := n.base.Trigger(ctx, n.logger, n.Destination.Errors()) diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index 1ae746aca..1fcb89e4e 100644 --- a/pkg/pipeline/stream/destination_acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -57,23 +57,13 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { defer func() { close(signalChan) workerErr := <-errChan - if workerErr != nil { - if err != nil { - // we are already returning an error, log this one instead - n.logger.Err(ctx, workerErr).Msg("destination acker node worker failed") - } else { - err = workerErr - } - } + err = cerrors.LogOrReplace(err, workerErr, func() { + n.logger.Err(ctx, workerErr).Msg("destination acker node worker failed") + }) teardownErr := n.teardown(err) - if teardownErr != nil { - if err != nil { - // we are already returning an error, log this one instead - n.logger.Err(ctx, teardownErr).Msg("destination acker node stopped before processing all messages") - } else { - err = teardownErr - } - } + err = cerrors.LogOrReplace(err, teardownErr, func() { + n.logger.Err(ctx, teardownErr).Msg("destination acker node stopped before processing all messages") + }) }() trigger, cleanup, err := n.base.Trigger(ctx, n.logger, errChan) diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go index 38847e899..94283a2b7 100644 --- a/pkg/pipeline/stream/source.go +++ b/pkg/pipeline/stream/source.go @@ -66,14 +66,9 @@ func (n *SourceNode) Run(ctx context.Context) (err error) { openMsgTracker.Wait() n.logger.Trace(ctx).Msg("all messages processed, tearing down source") tdErr := n.Source.Teardown(connectorCtx) - if tdErr != nil { - if err == nil { - err = tdErr - } else { - // we are already returning an error, just log this error - n.logger.Err(ctx, err).Msg("could not tear down source connector") - } - } + err = cerrors.LogOrReplace(err, tdErr, func() { + n.logger.Err(ctx, tdErr).Msg("could not tear down source connector") + }) }() trigger, cleanup, err := n.base.Trigger( From f18f96c337f390a9fa0ccef887dc054228a6b36f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Fri, 8 Jul 2022 19:55:19 +0200 Subject: [PATCH 35/46] make signal channel buffered --- pkg/pipeline/stream/destination_acker.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/pipeline/stream/destination_acker.go b/pkg/pipeline/stream/destination_acker.go index 1fcb89e4e..f81b7938b 100644 --- a/pkg/pipeline/stream/destination_acker.go +++ b/pkg/pipeline/stream/destination_acker.go @@ -51,7 +51,8 @@ func (n 
*DestinationAckerNode) Run(ctx context.Context) (err error) { connectorCtx, cancel := context.WithCancel(context.Background()) defer cancel() - signalChan := make(chan struct{}) + // signalChan is buffered to ensure signals don't get lost if worker is busy + signalChan := make(chan struct{}, 1) errChan := make(chan error) defer func() { @@ -75,7 +76,6 @@ func (n *DestinationAckerNode) Run(ctx context.Context) (err error) { // start worker that will fetch acks from the connector and forward them to // internal messages go n.worker(connectorCtx, signalChan, errChan) - signalChan <- struct{}{} // wait for worker to start before fetching first message defer cleanup() for { From 64b7d2549c74e6b58dd179aff5188e90ba24114e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Mon, 11 Jul 2022 15:04:03 +0200 Subject: [PATCH 36/46] improve benchmarks --- .../semaphore/semaphore_bench_test.go | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go index 8858428ed..d4448f5d5 100644 --- a/pkg/foundation/semaphore/semaphore_bench_test.go +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -15,7 +15,6 @@ package semaphore_test import ( - "container/list" "fmt" "testing" @@ -28,7 +27,7 @@ func BenchmarkNewSem(b *testing.B) { } } -func BenchmarkAcquireSem(b *testing.B) { +func BenchmarkEnqueueOneByOne(b *testing.B) { for _, N := range []int{1, 2, 8, 64, 128} { b.Run(fmt.Sprintf("acquire-%d", N), func(b *testing.B) { b.ResetTimer() @@ -36,30 +35,27 @@ func BenchmarkAcquireSem(b *testing.B) { for i := 0; i < b.N; i++ { for j := 0; j < N; j++ { t := sem.Enqueue() - _ = sem.Acquire(t) - _ = sem.Release(t) + sem.Acquire(t) + sem.Release(t) } } }) } } -func BenchmarkEnqueueReleaseSem(b *testing.B) { +func BenchmarkEnqueueAll(b *testing.B) { for _, N := range []int{1, 2, 8, 64, 128} { b.Run(fmt.Sprintf("enqueue/release-%d", N), func(b *testing.B) { - b.ResetTimer() sem := &semaphore.Simple{} - tickets := list.New() + tickets := make([]semaphore.Ticket, N) + b.ResetTimer() for i := 0; i < b.N; i++ { - tickets.Init() for j := 0; j < N; j++ { - t := sem.Enqueue() - tickets.PushBack(t) + tickets[j] = sem.Enqueue() } - ticket := tickets.Front() - for ticket != nil { - _ = sem.Release(ticket.Value.(semaphore.Ticket)) - ticket = ticket.Next() + for j := 0; j < N; j++ { + _ = sem.Acquire(tickets[j]) + _ = sem.Release(tickets[j]) } } }) From 24c33860ad8fcfeb9f194b644c80b264e6dd4c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Mon, 11 Jul 2022 15:08:37 +0200 Subject: [PATCH 37/46] fix linter error --- pkg/foundation/semaphore/semaphore_bench_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go index d4448f5d5..4e63dd130 100644 --- a/pkg/foundation/semaphore/semaphore_bench_test.go +++ b/pkg/foundation/semaphore/semaphore_bench_test.go @@ -35,8 +35,8 @@ func BenchmarkEnqueueOneByOne(b *testing.B) { for i := 0; i < b.N; i++ { for j := 0; j < N; j++ { t := sem.Enqueue() - sem.Acquire(t) - sem.Release(t) + _ = sem.Acquire(t) + _ = sem.Release(t) } } }) From d4dd111fe999552cf24c1286cb333e161ab2fc5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Mon, 11 Jul 2022 15:30:03 +0200 Subject: [PATCH 38/46] add comments --- pkg/foundation/semaphore/semaphore.go | 32 ++++++++++++++++++++------- 1 file changed, 24 
insertions(+), 8 deletions(-)

diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go
index e68081a40..2b535d9a5 100644
--- a/pkg/foundation/semaphore/semaphore.go
+++ b/pkg/foundation/semaphore/semaphore.go
@@ -23,26 +23,42 @@ import (
 // Simple provides a way to bound concurrent access to a resource. It only
 // allows one caller to gain access at a time.
 type Simple struct {
-	waiters []waiter
-	front   int
-	batch   int64
+	// waiters stores all waiters that have yet to acquire the semaphore. The
+	// slice will grow while new waiters are coming in, once tickets for all
+	// waiters are released the batch is incremented and the slice is reset.
+	waiters []waiter
+	// front is the index of the waiter that is next in line to acquire the
+	// semaphore.
+	front int
+	// batch is incremented every time the waiters slice is reset.
+	batch int64
+	// acquired is true if the semaphore is currently in the acquired state and
+	// needs to be released before it's acquired again.
 	acquired bool
+	// released gets incremented every time a ticket is released. Once the count
+	// of waiters equals the number of released tickets the batch gets
+	// incremented and the waiters slice is reset.
 	released int
-	mu       sync.Mutex
+	// mu guards concurrent access to the fields above.
+	mu sync.Mutex
 }
 
 type waiter struct {
-	index int
-	ready chan struct{} // Closed when semaphore acquired.
+	// ready is closed when semaphore acquired.
+	ready chan struct{}
 
-	released bool
+	// acquired is set to true once the waiter acquires the semaphore.
 	acquired bool
+	// released is set to true once the waiter releases the semaphore.
+	released bool
 }
 
 // Ticket reserves a place in the queue and can be used to acquire access to a
 // resource.
 type Ticket struct {
+	// index stores the index of the waiter in the semaphore.
 	index int
+	// batch stores the batch in which this ticket was issued.
 	batch int64
 }
 
@@ -54,7 +70,7 @@ func (s *Simple) Enqueue() Ticket {
 	defer s.mu.Unlock()
 
 	index := len(s.waiters)
-	w := waiter{index: index, ready: make(chan struct{})}
+	w := waiter{ready: make(chan struct{})}
 	s.waiters = append(s.waiters, w)
 
 	return Ticket{

From 3ebb744b777be31b9ec8d68ba3687243e6cc5630 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 11 Jul 2022 16:48:26 +0200
Subject: [PATCH 39/46] simplify implementation

---
 pkg/foundation/semaphore/semaphore.go      | 129 +++++-------------
 .../semaphore/semaphore_bench_test.go      |  13 +-
 pkg/foundation/semaphore/semaphore_test.go |  62 ++-------
 3 files changed, 50 insertions(+), 154 deletions(-)

diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go
index 2b535d9a5..b950dcd44 100644
--- a/pkg/foundation/semaphore/semaphore.go
+++ b/pkg/foundation/semaphore/semaphore.go
@@ -23,43 +23,19 @@ import (
 // Simple provides a way to bound concurrent access to a resource. It only
 // allows one caller to gain access at a time.
 type Simple struct {
-	// waiters stores all waiters that have yet to acquire the semaphore. The
-	// slice will grow while new waiters are coming in, once tickets for all
-	// waiters are released the batch is incremented and the slice is reset.
-	waiters []waiter
-	// front is the index of the waiter that is next in line to acquire the
-	// semaphore.
-	front int
-	// batch is incremented every time the waiters slice is reset.
-	batch int64
-	// acquired is true if the semaphore is currently in the acquired state and
-	// needs to be released before it's acquired again.
From 3ebb744b777be31b9ec8d68ba3687243e6cc5630 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 11 Jul 2022 16:48:26 +0200
Subject: [PATCH 39/46] simplify implementation

---
 pkg/foundation/semaphore/semaphore.go      | 129 +++++-------------
 .../semaphore/semaphore_bench_test.go      |  13 +-
 pkg/foundation/semaphore/semaphore_test.go |  62 ++-------
 3 files changed, 50 insertions(+), 154 deletions(-)

diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go
index 2b535d9a5..b950dcd44 100644
--- a/pkg/foundation/semaphore/semaphore.go
+++ b/pkg/foundation/semaphore/semaphore.go
@@ -23,43 +23,19 @@ import (
 // Simple provides a way to bound concurrent access to a resource. It only
 // allows one caller to gain access at a time.
 type Simple struct {
-	// waiters stores all waiters that have yet to acquire the semaphore. The
-	// slice will grow while new waiters are coming in, once tickets for all
-	// waiters are released the batch is incremented and the slice is reset.
-	waiters []waiter
-	// front is the index of the waiter that is next in line to acquire the
-	// semaphore.
-	front int
-	// batch is incremented every time the waiters slice is reset, invalidating
-	// tickets issued in previous batches.
-	batch int64
-	// acquired is true if the semaphore is currently in the acquired state and
-	// needs to be released before it's acquired again.
-	acquired bool
-	// released gets incremented every time a ticket is released. Once the count
-	// of waiters equals the number of released tickets the batch gets
-	// incremented and the waiters slice is reset.
-	released int
-	// mu guards concurrent access to the fields above.
+	// lastTicket holds the last issued ticket.
+	lastTicket Ticket
+	// mu guards concurrent access to lastTicket.
 	mu sync.Mutex
 }
 
-type waiter struct {
-	// ready is closed when semaphore acquired.
-	ready chan struct{}
-
-	// acquired is set to true once the waiter acquires the semaphore.
-	acquired bool
-	// released is set to true once the waiter releases the semaphore.
-	released bool
-}
-
 // Ticket reserves a place in the queue and can be used to acquire access to a
 // resource.
 type Ticket struct {
-	// index stores the index of the waiter in the semaphore.
-	index int
-	// batch stores the batch in which this ticket was issued.
-	batch int64
+	// ready is closed when the ticket acquired the semaphore.
+	ready chan struct{}
+	// next is closed when the ticket is released.
+	next chan struct{}
 }
 
 // Enqueue reserves the next place in the queue and returns a Ticket used to
@@ -69,87 +45,44 @@ func (s *Simple) Enqueue() Ticket {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	index := len(s.waiters)
-	w := waiter{ready: make(chan struct{})}
-	s.waiters = append(s.waiters, w)
-
-	return Ticket{
-		index: index,
-		batch: s.batch,
+	t := Ticket{
+		ready: s.lastTicket.next,
+		next:  make(chan struct{}),
 	}
-}
-
-// Acquire acquires the semaphore, blocking until resources are available. On
-// success, returns nil. On failure, returns an error and leaves the semaphore
-// unchanged.
-func (s *Simple) Acquire(t Ticket) error {
-	s.mu.Lock()
-	if s.batch != t.batch {
-		s.mu.Unlock()
-		return cerrors.New("semaphore: invalid batch")
-	}
-
-	w := s.waiters[t.index]
-	if w.acquired {
-		return cerrors.New("semaphore: can't acquire ticket that was already acquired")
-	}
-
-	w.acquired = true // mark that Acquire was already called for this Ticket
-	s.waiters[t.index] = w
-
-	if s.front == t.index && !s.acquired {
-		s.front++
-		s.acquired = true
-		s.mu.Unlock()
-		return nil
+	if t.ready == nil {
+		// first time we create a ticket it will be already acquired
+		t.ready = make(chan struct{})
+		close(t.ready)
 	}
-	s.mu.Unlock()
 
-	<-w.ready
-	return nil
+	s.lastTicket = t
+	return t
 }
 
+// Acquire acquires the semaphore, blocking until all tickets enqueued before
+// this one have been released.
+func (s *Simple) Acquire(t Ticket) {
+	<-t.ready
+}
+
 // Release releases the semaphore and notifies the next in line if any.
 // If the ticket was already released the function returns an error. After the
 // ticket is released it should be discarded.
 func (s *Simple) Release(t Ticket) error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	if s.batch != t.batch {
-		return cerrors.New("semaphore: invalid batch")
-	}
-	w := s.waiters[t.index]
-	if !w.acquired {
+	select {
+	case <-t.ready:
+	default:
 		return cerrors.New("semaphore: can't release ticket that was not acquired")
 	}
-	if w.released {
+
+	select {
+	case <-t.next:
 		return cerrors.New("semaphore: ticket already released")
+	default:
 	}
-	w.released = true
-	s.waiters[t.index] = w
-	s.acquired = false
-	s.released++
-	s.notifyWaiter()
-	if s.released == len(s.waiters) {
-		s.increaseBatch()
+	if t.next != nil {
+		close(t.next)
 	}
 	return nil
 }
-
-func (s *Simple) notifyWaiter() {
-	if len(s.waiters) > s.front {
-		w := s.waiters[s.front]
-		s.acquired = true
-		s.front++
-		close(w.ready)
-	}
-}
-
-func (s *Simple) increaseBatch() {
-	s.waiters = s.waiters[:0]
-	s.batch++
-	s.front = 0
-	s.released = 0
-}
diff --git a/pkg/foundation/semaphore/semaphore_bench_test.go b/pkg/foundation/semaphore/semaphore_bench_test.go
index 4e63dd130..67e1e0e14 100644
--- a/pkg/foundation/semaphore/semaphore_bench_test.go
+++ b/pkg/foundation/semaphore/semaphore_bench_test.go
@@ -28,14 +28,13 @@ func BenchmarkNewSem(b *testing.B) {
 }
 
 func BenchmarkEnqueueOneByOne(b *testing.B) {
-	for _, N := range []int{1, 2, 8, 64, 128} {
-		b.Run(fmt.Sprintf("acquire-%d", N), func(b *testing.B) {
-			b.ResetTimer()
+	for _, N := range []int{1, 2, 8, 64, 128, 1024} {
+		b.Run(fmt.Sprintf("ticket-count-%d", N), func(b *testing.B) {
 			sem := &semaphore.Simple{}
 			for i := 0; i < b.N; i++ {
 				for j := 0; j < N; j++ {
 					t := sem.Enqueue()
-					_ = sem.Acquire(t)
+					sem.Acquire(t)
 					_ = sem.Release(t)
 				}
 			}
@@ -44,8 +43,8 @@ func BenchmarkEnqueueOneByOne(b *testing.B) {
 }
 
 func BenchmarkEnqueueAll(b *testing.B) {
-	for _, N := range []int{1, 2, 8, 64, 128} {
-		b.Run(fmt.Sprintf("enqueue/release-%d", N), func(b *testing.B) {
+	for _, N := range []int{1, 2, 8, 64, 128, 1024} {
+		b.Run(fmt.Sprintf("ticket-count-%d", N), func(b *testing.B) {
 			sem := &semaphore.Simple{}
 			tickets := make([]semaphore.Ticket, N)
 			b.ResetTimer()
@@ -54,7 +53,7 @@ func BenchmarkEnqueueAll(b *testing.B) {
 					tickets[j] = sem.Enqueue()
 				}
 				for j := 0; j < N; j++ {
-					_ = sem.Acquire(tickets[j])
+					sem.Acquire(tickets[j])
 					_ = sem.Release(tickets[j])
 				}
 			}
diff --git a/pkg/foundation/semaphore/semaphore_test.go b/pkg/foundation/semaphore/semaphore_test.go
index 021887c6c..42d3a84e8 100644
--- a/pkg/foundation/semaphore/semaphore_test.go
+++ b/pkg/foundation/semaphore/semaphore_test.go
@@ -29,13 +29,10 @@ const maxSleep = 1 * time.Millisecond
 
 func HammerSimple(sem *semaphore.Simple, loops int) {
 	for i := 0; i < loops; i++ {
 		tkn := sem.Enqueue()
-		err := sem.Acquire(tkn)
-		if err != nil {
-			panic(err)
-		}
+		sem.Acquire(tkn)
 		//nolint:gosec // math/rand is good enough for a test
 		time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond)
-		err = sem.Release(tkn)
+		err := sem.Release(tkn)
 		if err != nil {
 			panic(err)
 		}
@@ -64,7 +61,8 @@ func TestSimpleReleaseUnacquired(t *testing.T) {
 	t.Parallel()
 
 	w := &semaphore.Simple{}
-	tkn := w.Enqueue()
+	_ = w.Enqueue()    // first ticket is automatically acquired
+	tkn := w.Enqueue() // next should be unacquired
 	err := w.Release(tkn)
 	if err == nil {
 		t.Errorf("release of an unacquired ticket did not return an error")
@@ -76,11 +74,8 @@ func TestSimpleReleaseTwice(t *testing.T) {
 
 	w := &semaphore.Simple{}
 	tkn := w.Enqueue()
-	err := w.Acquire(tkn)
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
-	err = w.Release(tkn)
+	w.Acquire(tkn)
+	err := w.Release(tkn)
 	if err != nil {
 		t.Errorf("release of an acquired ticket errored out: %v", err)
 	}
@@ -91,41 +86,19 @@ func TestSimpleReleaseTwice(t *testing.T) {
 	}
 }
 
-func TestSimpleAcquireTwice(t *testing.T) {
-	t.Parallel()
-
-	w := &semaphore.Simple{}
-	tkn := w.Enqueue()
-	err := w.Acquire(tkn)
-	if err != nil {
-		t.Errorf("acquire of a ticket errored out: %v", err)
-	}
-
-	err = w.Acquire(tkn)
-	if err == nil {
-		t.Errorf("acquire of an already acquired ticket did not return an error")
-	}
-}
-
 func TestSimpleAcquire(t *testing.T) {
 	t.Parallel()
 
 	sem := &semaphore.Simple{}
 	tkn1 := sem.Enqueue()
-	err := sem.Acquire(tkn1)
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
+	sem.Acquire(tkn1)
 
 	tkn2done := make(chan struct{})
 	go func() {
 		defer close(tkn2done)
 		tkn2 := sem.Enqueue()
-		err := sem.Acquire(tkn2)
-		if err != nil {
-			t.Errorf("unexpected error: %v", err)
-		}
+		sem.Acquire(tkn2)
 	}()
 
 	select {
@@ -135,7 +108,7 @@ func TestSimpleAcquire(t *testing.T) {
 		// tkn2 Acquire is blocking as expected
 	}
 
-	err = sem.Release(tkn1)
+	err := sem.Release(tkn1)
 	if err != nil {
 		t.Errorf("unexpected error: %v", err)
 	}
@@ -161,10 +134,7 @@ func TestLargeAcquireDoesntStarve(t *testing.T) {
 	wg.Add(int(n))
 	for i := n; i > 0; i-- {
 		tkn := sem.Enqueue()
-		err := sem.Acquire(tkn)
-		if err != nil {
-			t.Errorf("unexpected error: %v", err)
-		}
+		sem.Acquire(tkn)
 
 		go func() {
 			defer func() {
@@ -181,22 +151,16 @@ func TestLargeAcquireDoesntStarve(t *testing.T) {
 				t.Errorf("unexpected error: %v", err)
 			}
 			tkn = sem.Enqueue()
-			err = sem.Acquire(tkn)
-			if err != nil {
-				t.Errorf("unexpected error: %v", err)
-			}
+			sem.Acquire(tkn)
 		}
 	}()
 	}
 
 	tkn := sem.Enqueue()
-	err := sem.Acquire(tkn)
-	if err != nil {
-		t.Errorf("unexpected error: %v", err)
-	}
+	sem.Acquire(tkn)
 
 	running = false
-	err = sem.Release(tkn)
+	err := sem.Release(tkn)
 	if err != nil {
 		t.Errorf("unexpected error: %v", err)
 	}
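The rewrite replaces the waiters slice with a chain of channels: each new ticket's ready channel is simply the previous ticket's next channel, so closing next wakes exactly one successor and FIFO order falls out of the construction. A stripped-down, hypothetical illustration of the same technique follows (this is not the package code; the mutex that the real Enqueue holds is omitted because everything here is enqueued before the goroutine starts):

package main

import "fmt"

type ticket struct {
	ready <-chan struct{} // closed when it's our turn
	next  chan struct{}   // closed when we're done
}

type fifoLock struct{ last chan struct{} }

func (l *fifoLock) enqueue() ticket {
	t := ticket{ready: l.last, next: make(chan struct{})}
	if t.ready == nil { // the very first ticket is immediately ready
		c := make(chan struct{})
		close(c)
		t.ready = c
	}
	l.last = t.next
	return t
}

func main() {
	var l fifoLock
	t1, t2 := l.enqueue(), l.enqueue() // t2.ready is t1.next

	go func() {
		<-t1.ready // acquire, returns immediately
		fmt.Println("first")
		close(t1.next) // release, hands over to t2
	}()

	<-t2.ready // blocks until t1 is released
	fmt.Println("second")
	close(t2.next)
}

Because a released ticket is garbage collected together with its channels, the semaphore no longer needs batches, a front index, or a released counter, and Release never touches the mutex at all.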
From dc31319e60a59a99d4fa6f485fe4d3dae854188e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 11 Jul 2022 17:06:19 +0200
Subject: [PATCH 40/46] update semaphore

---
 pkg/pipeline/stream/source_acker.go | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/pkg/pipeline/stream/source_acker.go b/pkg/pipeline/stream/source_acker.go
index 5c354b25f..682151cf0 100644
--- a/pkg/pipeline/stream/source_acker.go
+++ b/pkg/pipeline/stream/source_acker.go
@@ -84,10 +84,7 @@ func (n *SourceAckerNode) registerAckHandler(msg *Message, ticket semaphore.Ticket) {
 		})
 	}()
 	n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for ack")
-	err = n.sem.Acquire(ticket)
-	if err != nil {
-		return cerrors.Errorf("could not acquire semaphore for ack: %w", err)
-	}
+	n.sem.Acquire(ticket)
 
 	if n.fail {
 		n.logger.Trace(msg.Ctx).Msg("blocking forwarding of ack to source connector, because another message failed to be acked/nacked")
@@ -113,10 +110,7 @@ func (n *SourceAckerNode) registerNackHandler(msg *Message, ticket semaphore.Ticket) {
 		})
 	}()
 	n.logger.Trace(msg.Ctx).Msg("acquiring semaphore for nack")
-	err = n.sem.Acquire(ticket)
-	if err != nil {
-		return cerrors.Errorf("could not acquire semaphore for nack: %w", err)
-	}
+	n.sem.Acquire(ticket)
 
 	if n.fail {
 		n.logger.Trace(msg.Ctx).Msg("blocking forwarding of nack to DLQ handler, because another message failed to be acked/nacked")
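With the simplified contract, Acquire can no longer fail; it only blocks until every earlier ticket was released, which is exactly what the acker node needs. A self-contained, hypothetical sketch of the resulting ordering pattern (the message handling is elided; only the Enqueue/Acquire/Release choreography is real):

package main

import (
	"fmt"
	"sync"

	"github.com/conduitio/conduit/pkg/foundation/semaphore"
)

func main() {
	var (
		sem semaphore.Simple
		wg  sync.WaitGroup
	)

	for i := 0; i < 5; i++ {
		ticket := sem.Enqueue() // taken while messages are still in dispatch order
		wg.Add(1)
		go func(i int, t semaphore.Ticket) {
			defer wg.Done()
			// the ack for message i arrives here, possibly out of order
			sem.Acquire(t)       // block until all older messages were acked
			defer sem.Release(t) // let the next message in line through
			fmt.Println("forwarding ack", i) // always prints in order 0..4
		}(i, ticket)
	}
	wg.Wait()
}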
From 5a6e8a3858931f2b111230981ea4a2ca8abce795 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 11 Jul 2022 17:07:14 +0200
Subject: [PATCH 41/46] update param name

---
 pkg/pipeline/stream/message.go | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/pkg/pipeline/stream/message.go b/pkg/pipeline/stream/message.go
index 5a2c761e3..69be73212 100644
--- a/pkg/pipeline/stream/message.go
+++ b/pkg/pipeline/stream/message.go
@@ -129,7 +129,7 @@ func (m *Message) ID() string {
 // any status change of the message. This function can only be called if the
 // message status is open, otherwise it panics. Handlers are called in the
 // reverse order of how they were registered.
-func (m *Message) RegisterStatusHandler(mw StatusChangeHandler) {
+func (m *Message) RegisterStatusHandler(h StatusChangeHandler) {
 	m.init()
 	m.handlerGuard.Lock()
 	defer m.handlerGuard.Unlock()
@@ -141,7 +141,7 @@ func (m *Message) RegisterStatusHandler(mw StatusChangeHandler) {
 	next := m.handler
 	m.handler = func(msg *Message, change StatusChange) error {
 		// all handlers are called and errors collected
-		err1 := mw(msg, change)
+		err1 := h(msg, change)
 		err2 := next(msg, change)
 		return multierror.Append(err1, err2)
 	}
@@ -150,24 +150,24 @@
 // RegisterAckHandler is used to register a function that will be called when
 // the message is acked. This function can only be called if the message status
 // is open, otherwise it panics.
-func (m *Message) RegisterAckHandler(mw AckHandler) {
+func (m *Message) RegisterAckHandler(h AckHandler) {
 	m.RegisterStatusHandler(func(msg *Message, change StatusChange) error {
 		if change.New != MessageStatusAcked {
 			return nil // skip
 		}
-		return mw(msg)
+		return h(msg)
 	})
 }
 
 // RegisterNackHandler is used to register a function that will be called when
 // the message is nacked. This function can only be called if the message status
 // is open, otherwise it panics.
-func (m *Message) RegisterNackHandler(mw NackHandler) {
+func (m *Message) RegisterNackHandler(h NackHandler) {
 	m.RegisterStatusHandler(func(msg *Message, change StatusChange) error {
 		if change.New != MessageStatusNacked {
 			return nil // skip
 		}
-		return mw(msg, change.Reason)
+		return h(msg, change.Reason)
 	})
 	m.hasNackHandler = true
 }
@@ -175,12 +175,12 @@ func (m *Message) RegisterNackHandler(mw NackHandler) {
 // RegisterDropHandler is used to register a function that will be called when
 // the message is dropped. This function can only be called if the message
 // status is open, otherwise it panics.
-func (m *Message) RegisterDropHandler(mw DropHandler) {
+func (m *Message) RegisterDropHandler(h DropHandler) {
 	m.RegisterStatusHandler(func(msg *Message, change StatusChange) error {
 		if change.New != MessageStatusDropped {
 			return nil
 		}
-		mw(msg, change.Reason)
+		h(msg, change.Reason)
 		return nil
 	})
 }
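The renamed parameter feeds into the chaining scheme described in the comment above: every registration wraps the current handler, so handlers run newest-first and all errors are collected. A hypothetical, self-contained reduction of that scheme (errors.Join stands in for this repository's multierror.Append):

package main

import (
	"errors"
	"fmt"
)

type statusHandler func(status string) error

// register wraps the existing chain, mirroring RegisterStatusHandler above.
func register(chain *statusHandler, h statusHandler) {
	next := *chain
	*chain = func(status string) error {
		err1 := h(status)    // newest handler runs first
		err2 := next(status) // then the previously registered chain
		return errors.Join(err1, err2)
	}
}

func main() {
	chain := statusHandler(func(string) error { return nil })

	register(&chain, func(s string) error {
		fmt.Println("handler A sees", s)
		return nil
	})
	register(&chain, func(s string) error {
		fmt.Println("handler B sees", s)
		return errors.New("B failed")
	})

	// Prints B first, then A, then the collected error "B failed".
	fmt.Println(chain("acked"))
}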
From 54a65d1779081576e20f313513250db365812e5f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 11 Jul 2022 17:12:31 +0200
Subject: [PATCH 42/46] remove redundant if clause

---
 pkg/foundation/semaphore/semaphore.go | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pkg/foundation/semaphore/semaphore.go b/pkg/foundation/semaphore/semaphore.go
index b950dcd44..e51cfd6b4 100644
--- a/pkg/foundation/semaphore/semaphore.go
+++ b/pkg/foundation/semaphore/semaphore.go
@@ -81,8 +81,6 @@ func (s *Simple) Release(t Ticket) error {
 	default:
 	}
 
-	if t.next != nil {
-		close(t.next)
-	}
+	close(t.next)
 	return nil
 }

From e9b8c3bb63d63eb93b3e2df3dd6c6cf6fd4767d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 25 Jul 2022 14:19:04 +0200
Subject: [PATCH 43/46] make it possible to inject only control messages

---
 pkg/pipeline/stream/base.go   |  8 ++++----
 pkg/pipeline/stream/source.go | 10 ++++------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/pkg/pipeline/stream/base.go b/pkg/pipeline/stream/base.go
index 957474505..71cfd220f 100644
--- a/pkg/pipeline/stream/base.go
+++ b/pkg/pipeline/stream/base.go
@@ -175,18 +175,18 @@ func (n *pubNodeBase) Trigger(
 	return trigger, cleanup, nil
 }
 
-// InjectMessage can be used to inject a message into the message stream. This
-// is used to inject control messages like the last position message when
+// InjectControlMessage can be used to inject a message into the message stream.
+// This is used to inject control messages like the last position message when
 // stopping a source connector. It is a bit hacky, but it doesn't require us to
 // create a separate channel for signals which makes it performant and easiest
 // to implement.
-func (n *pubNodeBase) InjectMessage(ctx context.Context, message *Message) error {
+func (n *pubNodeBase) InjectControlMessage(ctx context.Context, msgType ControlMessageType) error {
 	n.lock.Lock()
 	defer n.lock.Unlock()
 
 	select {
 	case <-ctx.Done():
 		return ctx.Err()
-	case n.msgChan <- message:
+	case n.msgChan <- &Message{controlMessageType: msgType}:
 		return nil
 	}
 }
diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go
index 94283a2b7..7e2a55315 100644
--- a/pkg/pipeline/stream/source.go
+++ b/pkg/pipeline/stream/source.go
@@ -145,12 +145,10 @@ func (n *SourceNode) Run(ctx context.Context) (err error) {
 func (n *SourceNode) Stop(ctx context.Context, reason error) error {
 	n.stopReason = reason
-	// InjectMessage will inject a message into the stream of messages being
-	// processed by SourceNode to let it know when it should stop processing new
-	// messages.
-	return n.base.InjectMessage(ctx, &Message{
-		controlMessageType: ControlMessageStopSourceNode,
-	})
+	// InjectControlMessage will inject a message into the stream of messages
+	// being processed by SourceNode to let it know when it should stop
+	// processing new messages.
+	return n.base.InjectControlMessage(ctx, ControlMessageStopSourceNode)
 }
 
 func (n *SourceNode) Pub() <-chan *Message {
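A stripped-down, hypothetical sketch of the injection pattern introduced above: control signals travel through the same channel as regular messages, so they are ordered relative to records already in flight, and the context guards the send against blocking forever. The types here are simplified stand-ins, not the pipeline's real ones:

package main

import (
	"context"
	"fmt"
	"sync"
)

type message struct {
	control bool
	payload string
}

type node struct {
	lock    sync.Mutex
	msgChan chan message
}

func (n *node) injectControlMessage(ctx context.Context) error {
	n.lock.Lock()
	defer n.lock.Unlock()
	select {
	case <-ctx.Done():
		return ctx.Err() // don't block forever if the pipeline is shutting down
	case n.msgChan <- message{control: true}:
		return nil
	}
}

func main() {
	n := &node{msgChan: make(chan message, 2)}
	n.msgChan <- message{payload: "record-1"}
	_ = n.injectControlMessage(context.Background())

	for i := 0; i < 2; i++ {
		fmt.Println(<-n.msgChan) // the record first, then the control message
	}
}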
From 7d088d6f573c75224488354f978f071b7a2c0f53 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 25 Jul 2022 17:41:34 +0200
Subject: [PATCH 44/46] improve destination acker caching test

---
 pkg/pipeline/stream/destination_acker_test.go | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/pkg/pipeline/stream/destination_acker_test.go b/pkg/pipeline/stream/destination_acker_test.go
index 7b9463472..48f18defb 100644
--- a/pkg/pipeline/stream/destination_acker_test.go
+++ b/pkg/pipeline/stream/destination_acker_test.go
@@ -51,8 +51,16 @@ func TestDestinationAckerNode_Cache(t *testing.T) {
 
 	const count = 1000
 	currentPosition := 0
+
+	// create wait group that will be done once we send all messages to the node
+	var msgProducerWg sync.WaitGroup
+	msgProducerWg.Add(1)
+
 	dest.EXPECT().Ack(gomock.Any()).
 		DoAndReturn(func(ctx context.Context) (record.Position, error) {
+			// wait for all messages to be produced, this means the node should
+			// be caching them
+			msgProducerWg.Wait()
 			pos := fmt.Sprintf("test-position-%d", currentPosition)
 			currentPosition++
 			return record.Position(pos), nil
@@ -71,6 +79,7 @@ func TestDestinationAckerNode_Cache(t *testing.T) {
 		})
 		in <- msg
 	}
+	msgProducerWg.Done()
 
 	// note that there should be no calls to the destination at all if the node
 	// didn't receive any messages
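The test now encodes the caching claim as a synchronization invariant: the mocked Ack blocks on the wait group until the producer has pushed all messages, so the test can only pass if the node buffers them instead of acking one by one. A hypothetical reduction of the trick, with a buffered channel standing in for the node's internal cache:

package main

import (
	"fmt"
	"sync"
)

func main() {
	var producerWg sync.WaitGroup
	producerWg.Add(1)

	in := make(chan int, 1000) // stands in for the node's cache

	ackReady := make(chan struct{})
	go func() {
		producerWg.Wait() // "mock Ack": don't answer until production finished
		close(ackReady)
	}()

	for i := 0; i < 1000; i++ {
		in <- i // only succeeds for all messages if the consumer buffers them
	}
	producerWg.Done()

	<-ackReady
	fmt.Println("all messages were in flight before the first ack was answered")
}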
From d9a0087303aa7645aba93ce5256fd7dd279f623e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Mon, 25 Jul 2022 17:55:44 +0200
Subject: [PATCH 45/46] remove TODO comment

---
 pkg/pipeline/stream/source.go | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pkg/pipeline/stream/source.go b/pkg/pipeline/stream/source.go
index 7e2a55315..3817c6410 100644
--- a/pkg/pipeline/stream/source.go
+++ b/pkg/pipeline/stream/source.go
@@ -109,7 +109,6 @@ func (n *SourceNode) Run(ctx context.Context) (err error) {
 			n.logger.Err(ctx, n.stopReason).Msg("stopping source connector")
 			stopPosition, err = n.Source.Stop(ctx)
 			if err != nil {
-				// TODO think through if just exiting here makes sense
 				return cerrors.Errorf("failed to stop source connector: %w", err)
 			}

From 60995425f4f1ad15ca396dca75aef7db5b50bb6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?=
Date: Tue, 26 Jul 2022 17:23:43 +0200
Subject: [PATCH 46/46] update comment

---
 pkg/pipeline/stream/node.go | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pkg/pipeline/stream/node.go b/pkg/pipeline/stream/node.go
index 6cb8d1d48..564f57950 100644
--- a/pkg/pipeline/stream/node.go
+++ b/pkg/pipeline/stream/node.go
@@ -32,7 +32,9 @@ type Node interface {
 	// as soon as the supplied context is done. If an error occurs while
 	// processing messages, the processing should stop and the error should be
 	// returned. If processing stopped because the context was canceled, the
-	// function should return ctx.Err().
+	// function should return ctx.Err(). All nodes that are part of the same
+	// pipeline will receive the same context in Run and as soon as one node
+	// returns an error the context will be canceled.
 	// Run has different responsibilities, depending on the node type:
 	//  * PubNode has to start producing new messages into the outgoing channel.
 	//    The context supplied to Run has to be attached to all messages. Each