Skip to content

Commit

Permalink
Cancel consumer before closing channel (#70)
Browse files Browse the repository at this point in the history
* Cancel consumer before closing channel

* consumer cancel: fix tests

* consumer cancel: fix linter
  • Loading branch information
makasim authored Jan 26, 2023
1 parent 5826978 commit e05133a
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 2 deletions.
33 changes: 33 additions & 0 deletions consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package consumer
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/makasim/amqpextra/logger"
Expand Down Expand Up @@ -52,6 +55,7 @@ type AMQPChannel interface {
NotifyCancel(c chan string) chan string
QueueDeclare(name string, durable, autoDelete, exclusive, noWait bool, args amqp.Table) (amqp.Queue, error)
QueueBind(name, key, exchange string, noWait bool, args amqp.Table) error
Cancel(consumer string, noWait bool) error
Close() error
}

Expand Down Expand Up @@ -149,6 +153,10 @@ func New(
return nil, fmt.Errorf("handler must be not nil")
}

if c.consumer == `` {
c.consumer = uniqueConsumerTag()
}

if c.queue == "" && c.exchange == "" && !c.queueDeclare {
return nil, fmt.Errorf("WithQueue or WithExchange or WithDeclareQueue or WithTmpQueue options must be set")
}
Expand Down Expand Up @@ -522,8 +530,33 @@ func (c *Consumer) notifyUnready(err error) State {

func (c *Consumer) close(ch AMQPChannel) {
if ch != nil {
if err := ch.Cancel(c.consumer, false); err != nil {
c.logger.Printf("[WARN] channel cancel: %s", err)
}
if err := ch.Close(); err != nil && !strings.Contains(err.Error(), "channel/connection is not open") {
c.logger.Printf("[WARN] channel close: %s", err)
}
}
}

// COPY AND PASTE from amqp091 library

var consumerSeq uint64

const consumerTagLengthMax = 0xFF // see writeShortstr

func uniqueConsumerTag() string {
return commandNameBasedUniqueConsumerTag(os.Args[0])
}

func commandNameBasedUniqueConsumerTag(commandName string) string {
tagPrefix := "ctag-"
tagInfix := commandName
tagSuffix := "-" + strconv.FormatUint(atomic.AddUint64(&consumerSeq, 1), 10)

if len(tagPrefix)+len(tagInfix)+len(tagSuffix) > consumerTagLengthMax {
tagInfix = "streadway/amqp"
}

return tagPrefix + tagInfix + tagSuffix
}
64 changes: 62 additions & 2 deletions consumer/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ func TestNotify(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().Cancel(any(), false).AnyTimes()
ch.EXPECT().Close().AnyTimes()

connCh := make(chan *consumer.Connection, 1)
Expand Down Expand Up @@ -162,6 +163,7 @@ func TestNotify(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().Cancel(any(), false).AnyTimes()
ch.EXPECT().Close().AnyTimes()

connCh := make(chan *consumer.Connection, 1)
Expand Down Expand Up @@ -211,7 +213,7 @@ func TestNotify(main *testing.T) {
DeclareArgs: amqp.Table{
"foo": "fooVal",
},

Consumer: "theConsumer",
AutoAck: true,
Exclusive: true,
Expand Down Expand Up @@ -683,6 +685,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Times(1)
ch.EXPECT().Qos(any(), any(), any()).
Times(1)
Expand Down Expand Up @@ -731,12 +734,13 @@ func TestConsume(main *testing.T) {
msgCh := make(chan amqp.Delivery)

ch := mock_consumer.NewMockAMQPChannel(ctrl)
ch.EXPECT().Consume("theQueue", "", false, false, false, false, amqp.Table(nil)).
ch.EXPECT().Consume("theQueue", &consumerTagMatcher{t: t}, false, false, false, false, amqp.Table(nil)).
Return(msgCh, nil).Times(1)
ch.EXPECT().NotifyClose(any()).
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Times(1)
ch.EXPECT().Qos(any(), any(), any()).
Times(1)
Expand Down Expand Up @@ -792,6 +796,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(`theConsumer`, false).Times(1)
ch.EXPECT().Close().Times(1)
ch.EXPECT().Qos(any(), any(), any()).
Times(1)
Expand Down Expand Up @@ -847,6 +852,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Times(1)
ch.EXPECT().Qos(any(), any(), any()).
Times(1)
Expand Down Expand Up @@ -905,6 +911,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(nil).Times(1)

newChCloseCh := make(chan *amqp.Error)
Expand All @@ -920,6 +927,7 @@ func TestConsume(main *testing.T) {
Return(newChCloseCh).Times(1)
newCh.EXPECT().NotifyCancel(any()).
Return(newCancelCh).Times(1)
newCh.EXPECT().Cancel(any(), false).Times(1)
newCh.EXPECT().Close().Times(1)

conn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand Down Expand Up @@ -984,6 +992,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(fmt.Errorf("the error")).Times(1)
ch.EXPECT().Qos(any(), any(), any()).
Times(1)
Expand Down Expand Up @@ -1037,6 +1046,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(amqp.ErrClosed).Times(1)
ch.EXPECT().Qos(any(), any(), any()).
Times(1)
Expand Down Expand Up @@ -1092,6 +1102,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(nil).Times(1)

conn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand All @@ -1113,6 +1124,7 @@ func TestConsume(main *testing.T) {
Return(newChCloseCh).Times(1)
newCh.EXPECT().NotifyCancel(any()).
Return(newCancelCh).Times(1)
newCh.EXPECT().Cancel(any(), false).Times(1)
newCh.EXPECT().Close().Times(1)

c, err := consumer.New(
Expand Down Expand Up @@ -1176,6 +1188,7 @@ func TestConsume(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(nil).Times(1)

conn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand Down Expand Up @@ -1241,6 +1254,7 @@ func TestConcurrency(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(nil).Times(1)

amqpConn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand All @@ -1256,6 +1270,7 @@ func TestConcurrency(main *testing.T) {
Return(newChCloseCh).Times(1)
newCh.EXPECT().NotifyCancel(any()).
Return(newCancelCh).Times(1)
newCh.EXPECT().Cancel(any(), false).Times(1)
newCh.EXPECT().Close().Return(nil).Times(1)

newConn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand Down Expand Up @@ -1333,6 +1348,7 @@ func TestConcurrency(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(amqp.ErrClosed).Times(1)

conn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand Down Expand Up @@ -1409,6 +1425,7 @@ func TestConcurrency(main *testing.T) {
Return(chCloseCh).Times(1)
ch.EXPECT().NotifyCancel(any()).
Return(cancelCh).Times(1)
ch.EXPECT().Cancel(any(), false).Times(1)
ch.EXPECT().Close().Return(nil).Times(1)

newChCloseCh := make(chan *amqp.Error)
Expand All @@ -1422,6 +1439,7 @@ func TestConcurrency(main *testing.T) {
Return(newChCloseCh).Times(1)
newCh.EXPECT().NotifyCancel(any()).
Return(newCancelCh).Times(1)
newCh.EXPECT().Cancel(any(), false).Times(1)
newCh.EXPECT().Close().Return(nil).Times(1)

conn := mock_consumer.NewMockAMQPConnection(ctrl)
Expand Down Expand Up @@ -1501,6 +1519,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyCancel(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1565,6 +1586,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyCancel(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1624,6 +1648,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyCancel(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1682,6 +1709,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyCancel(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1740,6 +1770,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyCancel(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1840,6 +1873,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1891,6 +1927,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -1950,6 +1989,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -2001,6 +2043,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -2056,6 +2101,9 @@ func TestOptions(main *testing.T) {
ch.EXPECT().
NotifyClose(any()).
AnyTimes()
ch.EXPECT().
Cancel(any(), false).
AnyTimes()
ch.EXPECT().
Close().
AnyTimes()
Expand Down Expand Up @@ -2165,3 +2213,15 @@ func initFuncStub(chs ...consumer.AMQPChannel) func(consumer.AMQPConnection) (co
return currCh, nil
}
}

type consumerTagMatcher struct {
t *testing.T
}

func (m *consumerTagMatcher) Matches(x interface{}) bool {
return assert.IsType(m.t, `string`, x) && assert.NotEmpty(m.t, x)
}

func (*consumerTagMatcher) String() string {
return `consumer tag must be a non-empty string`
}
14 changes: 14 additions & 0 deletions consumer/mock_consumer/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e05133a

Please sign in to comment.