From 82136d305c846ef793dc689b1cbb4268024f1895 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Tue, 30 Apr 2024 21:30:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/consumer.go | 66 +++++++++-------- memory/consumergroup.go | 153 +++++++++++++++++++--------------------- memory/mq.go | 7 +- memory/mq_test.go | 6 +- memory/producer.go | 3 +- memory/topic.go | 3 +- 6 files changed, 120 insertions(+), 118 deletions(-) diff --git a/memory/consumer.go b/memory/consumer.go index ef472db..76f803e 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -61,55 +61,59 @@ func (c *Consumer) Consume(ctx context.Context) (*mq.Message, error) { } // 启动Consume -func (c *Consumer) Run() { +func (c *Consumer) eventLoop() { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: log.Printf("消费者 %s 开始消费数据", c.name) - for idx, record := range c.partitionRecords { - msgs := c.partitions[record.Index].getBatch(record.Cursor, limit) - for _, msg := range msgs { - log.Printf("消费者 %s 消费数据 %v", c.name, msg) - c.msgCh <- msg - } - record.Cursor += len(msgs) - errCh := make(chan error, 1) - c.reportCh <- &Event{ - Type: ReportOffset, - Data: ReportData{ - Records: []PartitionRecord{record}, - ErrChan: errCh, - }, - } - err := <-errCh - if err != nil { - log.Printf("上报偏移量失败:%v", err) - break - } - close(errCh) - c.partitionRecords[idx] = record - } + c.consumerAndReport() log.Printf("消费者 %s 结束消费数据", c.name) case event, ok := <-c.receiveCh: if !ok { return } // 处理各种事件 - c.Handle(event) + c.handle(event) + } + } +} + +func (c *Consumer) consumerAndReport() { + for idx, record := range c.partitionRecords { + msgs := c.partitions[record.Index].getBatch(record.Offset, limit) + for _, msg := range msgs { + log.Printf("消费者 %s 消费数据 %v", c.name, msg) + c.msgCh <- msg + } + record.Offset += len(msgs) + errCh := make(chan error, 1) + c.reportCh <- &Event{ + Type: ReportOffsetEvent, + Data: ReportData{ + Records: []PartitionRecord{record}, + ErrChan: errCh, + }, + } + err := <-errCh + if err != nil { + log.Printf("上报偏移量失败:%v", err) + return } + close(errCh) + c.partitionRecords[idx] = record } } -func (c *Consumer) Handle(event *Event) { +func (c *Consumer) handle(event *Event) { switch event.Type { - // 服务端发起的重平衡事件 - case Rejoin: + // 服务端发起的重新加入事件 + case RejoinEvent: // 消费者上报消费进度 log.Printf("消费者 %s开始上报消费进度", c.name) c.reportCh <- &Event{ - Type: RejoinAck, + Type: RejoinAckEvent, Data: c.partitionRecords, } // 设置消费进度 @@ -118,9 +122,9 @@ func (c *Consumer) Handle(event *Event) { c.partitionRecords, _ = partitionInfo.Data.([]PartitionRecord) // 返回设置完成的信号 c.reportCh <- &Event{ - Type: PartitionNotifyAck, + Type: PartitionNotifyAckEvent, } - case Close: + case CloseEvent: c.Close() } } diff --git a/memory/consumergroup.go b/memory/consumergroup.go index 4a1ead1..d9991a4 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -17,6 +17,7 @@ package memory import ( "fmt" "log" + "sync" "sync/atomic" "time" @@ -36,17 +37,17 @@ const ( // ExitGroupEvent 退出事件 ExitGroupEvent = "exit_group" - // ReportOffset 上报偏移量事件 - ReportOffset = "report_offset" - // Rejoin 通知consumer重新加入消费组 - Rejoin = "rejoin" - // RejoinAck 表示客户端收到重新加入消费组的指令并将offset进行上报 - RejoinAck = "rejoin_ack" - Close = "close" - // PartitionNotify 下发分区情况事件 - PartitionNotify = "partition_notify" - // PartitionNotifyAck 下发分区情况确认 - PartitionNotifyAck = "partition_notify_ack" + // ReportOffsetEvent 上报偏移量事件 + ReportOffsetEvent = "report_offset" + // RejoinEvent 通知consumer重新加入消费组 + RejoinEvent = "rejoin" + // RejoinAckEvent 表示客户端收到重新加入消费组的指令并将offset进行上报 + RejoinAckEvent = "rejoin_ack" + CloseEvent = "close" + // PartitionNotifyEvent 下发分区情况事件 + PartitionNotifyEvent = "partition_notify" + // PartitionNotifyAckEvent 下发分区情况确认事件 + PartitionNotifyAckEvent = "partition_notify_ack" StatusStable = 1 // 稳定状态,可以正常的进行消费数据 StatusBalancing = 2 @@ -56,38 +57,29 @@ const ( type ConsumerGroup struct { name string // 存储消费者元数据,键为消费者的名称 - consumers syncx.Map[string, *ConsumerMetaData] + consumers syncx.Map[string, *Consumer] // 消费者平衡器 - consumerPartitionBalancer ConsumerPartitionAssigner + consumerPartitionAssigner ConsumerPartitionAssigner // 分区消费记录 partitionRecords *syncx.Map[int, PartitionRecord] // 分区 partitions []*Partition status int32 - // 用于接受在重平衡阶段channel的返回数据 - balanceCh chan struct{} + balanceCh chan struct{} + once sync.Once } type PartitionRecord struct { // 属于哪个分区 Index int // 消费进度 - Cursor int + Offset int } type ReportData struct { Records []PartitionRecord ErrChan chan error } -type ConsumerMetaData struct { - // 用于消费者上报数据,如退出或加入消费组,上报消费者消费分区的偏移量 - reportCh chan *Event - // 用于消费组给消费者下发数据,如下发开始重平衡开始通知,告知消费者可以消费的channel - receiveCh chan *Event - // 消费者的名字 - name string -} - type Event struct { // 事件类型 Type string @@ -95,32 +87,32 @@ type Event struct { Data any } -func (c *ConsumerGroup) Handler(name string, event *Event) { +func (c *ConsumerGroup) eventHandler(name string, event *Event) { switch event.Type { case ExitGroupEvent: closeCh, _ := event.Data.(chan struct{}) - c.ExitGroup(name, closeCh) - case ReportOffset: + c.exitGroup(name, closeCh) + case ReportOffsetEvent: data, _ := event.Data.(ReportData) var err error - err = c.ReportOffset(data.Records) + err = c.reportOffset(data.Records) data.ErrChan <- err log.Printf("消费者%s上报offset成功", name) - case RejoinAck: + case RejoinAckEvent: // consumer响应重平衡信号返回的数据,返回的是当前所有分区的偏移量 records, _ := event.Data.([]PartitionRecord) // 不管上报成不成功 - _ = c.ReportOffset(records) + _ = c.reportOffset(records) log.Printf("消费者%s成功接受到重平衡信号,并上报offset", name) c.balanceCh <- struct{}{} - case PartitionNotifyAck: + case PartitionNotifyAckEvent: log.Printf("消费者%s 成功设置分区信息", name) c.balanceCh <- struct{}{} } } // ExitGroupEvent 退出消费组 -func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { +func (c *ConsumerGroup) exitGroup(name string, closeCh chan struct{}) { // 把自己从消费组内摘除 for { if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { @@ -138,8 +130,8 @@ func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { } } -// ReportOffset 上报偏移量 -func (c *ConsumerGroup) ReportOffset(records []PartitionRecord) error { +// ReportOffsetEvent 上报偏移量 +func (c *ConsumerGroup) reportOffset(records []PartitionRecord) error { if atomic.LoadInt32(&c.status) != StatusStable { return ErrReportOffsetFail } @@ -150,14 +142,16 @@ func (c *ConsumerGroup) ReportOffset(records []PartitionRecord) error { } func (c *ConsumerGroup) Close() { - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { - value.receiveCh <- &Event{ - Type: Close, - } - return true + c.once.Do(func() { + c.consumers.Range(func(key string, value *Consumer) bool { + value.receiveCh <- &Event{ + Type: CloseEvent, + } + return true + }) + // 等待一秒退出完成 + time.Sleep(1 * time.Second) }) - // 等待一秒退出完成 - time.Sleep(1 * time.Second) } // reBalance 单独使用该方法是并发不安全的 @@ -167,10 +161,10 @@ func (c *ConsumerGroup) reBalance() { length := 0 consumers := make([]string, 0, consumerCap) log.Println("开始给每个消费者,重平衡信号") - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + c.consumers.Range(func(key string, value *Consumer) bool { log.Printf("开始通知消费者%s", key) value.receiveCh <- &Event{ - Type: Rejoin, + Type: RejoinEvent, } consumers = append(consumers, key) length++ @@ -183,39 +177,40 @@ func (c *ConsumerGroup) reBalance() { select { case <-c.balanceCh: number++ - if number == length { - // 接收到所有信号 - log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") - consumerMap := c.consumerPartitionBalancer.AssignPartition(consumers, len(c.partitions)) - // 通知所有消费者分配 - log.Println("开始分配分区") - for consumerName, partitions := range consumerMap { - // 查找消费者所属的channel - log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) - consumer, ok := c.consumers.Load(consumerName) - if ok { - // 往每个消费者的receive_channel发送partition的信息 - records := make([]PartitionRecord, 0, len(partitions)) - for _, p := range partitions { - record, ok := c.partitionRecords.Load(p) - if ok { - records = append(records, record) - } - } - consumer.receiveCh <- &Event{ - Type: PartitionNotify, - Data: records, + if number != length { + continue + } + // 接收到所有信号 + log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") + consumerMap := c.consumerPartitionAssigner.AssignPartition(consumers, len(c.partitions)) + // 通知所有消费者分配 + log.Println("开始分配分区") + for consumerName, partitions := range consumerMap { + // 查找消费者所属的channel + log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) + consumer, ok := c.consumers.Load(consumerName) + if ok { + // 往每个消费者的receive_channel发送partition的信息 + records := make([]PartitionRecord, 0, len(partitions)) + for _, p := range partitions { + record, ok := c.partitionRecords.Load(p) + if ok { + records = append(records, record) } - // 等待消费者接收到并保存 - <-c.balanceCh - } + consumer.receiveCh <- &Event{ + Type: PartitionNotifyEvent, + Data: records, + } + // 等待消费者接收到并保存 + <-c.balanceCh + } - log.Println("重平衡结束") - return } + log.Println("重平衡结束") + return default: - + time.Sleep(defaultSleepTime) } } log.Println("重平衡结束") @@ -229,7 +224,7 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { continue } var length int - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + c.consumers.Range(func(key string, value *Consumer) bool { length++ return true }) @@ -245,13 +240,13 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { partitionRecords: []PartitionRecord{}, closeCh: make(chan struct{}), } - c.consumers.Store(name, &ConsumerMetaData{ + c.consumers.Store(name, &Consumer{ reportCh: reportCh, receiveCh: receiveCh, name: name, }) - go c.HandleConsumerSignals(name, reportCh) - go consumer.Run() + go c.handleConsumerEvents(name, reportCh) + go consumer.eventLoop() log.Printf("新建消费者 %s", name) // 重平衡分配分区 c.reBalance() @@ -260,10 +255,10 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { } } -// HandleConsumerSignals 处理消费者上报的事件 -func (c *ConsumerGroup) HandleConsumerSignals(name string, reportCh chan *Event) { +// handleConsumerEvents 处理消费者上报的事件 +func (c *ConsumerGroup) handleConsumerEvents(name string, reportCh chan *Event) { for event := range reportCh { - c.Handler(name, event) + c.eventHandler(name, event) if event.Type == ExitGroupEvent { close(reportCh) return diff --git a/memory/mq.go b/memory/mq.go index 134cf1a..62fe571 100644 --- a/memory/mq.go +++ b/memory/mq.go @@ -102,8 +102,8 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { if !ok { group = &ConsumerGroup{ name: groupID, - consumers: syncx.Map[string, *ConsumerMetaData]{}, - consumerPartitionBalancer: t.consumerPartitionAssigner, + consumers: syncx.Map[string, *Consumer]{}, + consumerPartitionAssigner: t.consumerPartitionAssigner, partitions: t.partitions, balanceCh: make(chan struct{}, defaultBalanceChLen), status: StatusStable, @@ -113,7 +113,7 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { for idx := range t.partitions { partitionRecords.Store(idx, PartitionRecord{ Index: idx, - Cursor: 0, + Offset: 0, }) } group.partitionRecords = &partitionRecords @@ -153,6 +153,7 @@ func (m *MQ) DeleteTopics(ctx context.Context, topics ...string) error { err := topic.Close() if err != nil { log.Printf("topic: %s关闭失败 %v", t, err) + continue } m.topics.Delete(t) } diff --git a/memory/mq_test.go b/memory/mq_test.go index 30026d8..a10f2ff 100644 --- a/memory/mq_test.go +++ b/memory/mq_test.go @@ -1,10 +1,11 @@ package memory import ( + "testing" + "github.com/ecodeclub/ekit/syncx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" ) func TestMQ(t *testing.T) { @@ -17,7 +18,8 @@ func TestMQ(t *testing.T) { require.NoError(t, err) _, ok := testmq.topics.Load("test_topic") assert.Equal(t, ok, true) - testmq.Producer("test_topic1") + _, err = testmq.Producer("test_topic1") + require.NoError(t, err) _, ok = testmq.topics.Load("test_topic1") assert.Equal(t, ok, true) } diff --git a/memory/producer.go b/memory/producer.go index 2c929a5..8bdd006 100644 --- a/memory/producer.go +++ b/memory/producer.go @@ -16,9 +16,10 @@ package memory import ( "context" - "github.com/ecodeclub/mq-api/internal/errs" "sync/atomic" + "github.com/ecodeclub/mq-api/internal/errs" + "github.com/ecodeclub/mq-api" ) diff --git a/memory/topic.go b/memory/topic.go index 74e30c4..8bb76b5 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -65,12 +65,11 @@ func (t *Topic) addProducer(producer mq.Producer) error { } // addMessage 往分区里面添加消息 -// 发送消息 producer生成msg--->add func (t *Topic) addMessage(msg *mq.Message) error { partitionID := t.producerPartitionIDGetter.PartitionID(string(msg.Key)) return t.addMessageWithPartition(msg, partitionID) - } + func (t *Topic) addMessageWithPartition(msg *mq.Message, partitionID int64) error { if partitionID < 0 || int(partitionID) >= len(t.partitions) { return errs.ErrInvalidPartition